Skip to content

Commit

Permalink
fixes and moved all validation to encoder
Browse files Browse the repository at this point in the history
Signed-off-by: AnnTian Shao <[email protected]>
  • Loading branch information
AnnTian Shao committed Feb 10, 2025
1 parent 697953f commit 2cc1523
Show file tree
Hide file tree
Showing 10 changed files with 72 additions and 111 deletions.
5 changes: 4 additions & 1 deletion src/main/java/org/opensearch/knn/index/engine/Encoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,8 @@ default String getName() {
*
* @return Validation output of encoder parameters
*/
TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput validationInput);
/**
 * Default encoder validation: performs no checks and returns an output built
 * from a fresh builder with no fields set. Encoders that need validation override this.
 *
 * @param validationInput training configuration to validate (not inspected by this default)
 * @return validation output built without setting any field
 */
default TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput validationInput) {
    return TrainingConfigValidationOutput.builder().build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,33 @@ static MethodComponentContext getEncoderMethodComponent(MethodComponentContext m
return (MethodComponentContext) object;
}

/**
 * Resolves the name of the encoder configured in the given method context.
 *
 * @param knnMethodContext method context to inspect; may be null
 * @return the encoder component's name, or null when no encoder is specified
 *         or the encoder component context is null
 */
protected String getEncoderName(KNNMethodContext knnMethodContext) {
    if (isEncoderSpecified(knnMethodContext)) {
        final MethodComponentContext encoderContext = getEncoderComponentContext(knnMethodContext);
        return encoderContext == null ? null : encoderContext.getName();
    }

    return null;
}

/**
 * Looks up the encoder component stored under {@code METHOD_ENCODER_PARAMETER}
 * in the method's parameters.
 *
 * @param knnMethodContext method context to inspect; may be null
 * @return the encoder parameter cast to {@link MethodComponentContext}, or null
 *         when no encoder is specified
 */
protected MethodComponentContext getEncoderComponentContext(KNNMethodContext knnMethodContext) {
    if (isEncoderSpecified(knnMethodContext)) {
        final Object encoderParameter = knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_ENCODER_PARAMETER);
        return (MethodComponentContext) encoderParameter;
    }

    return null;
}

/**
 * Tells whether the method context carries an encoder parameter.
 *
 * @param knnMethodContext method context to inspect; may be null
 * @return true only when the context is non-null, its method component has a
 *         non-null parameter map, and that map contains {@code METHOD_ENCODER_PARAMETER}
 */
protected boolean isEncoderSpecified(KNNMethodContext knnMethodContext) {
    if (knnMethodContext == null) {
        return false;
    }
    if (knnMethodContext.getMethodComponentContext().getParameters() == null) {
        return false;
    }
    return knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_ENCODER_PARAMETER);
}

@Override
protected SpaceType convertUserToMethodSpaceType(SpaceType spaceType) {
// While FAISS doesn't directly support cosine similarity, we can leverage the mathematical
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;

import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
Expand Down Expand Up @@ -98,6 +99,7 @@ public CompressionLevel calculateCompressionLevel(
public TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput trainingConfigValidationInput) {
KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
KNNMethodConfigContext knnMethodConfigContext = trainingConfigValidationInput.getKnnMethodConfigContext();
Long trainingVectors = trainingConfigValidationInput.getTrainingVectorsCount();

TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();

Expand All @@ -113,6 +115,29 @@ public TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValida
builder.valid(true);
}
}

// Validate that the number of training vectors is at least the minimum clustering requirement defined in faiss
if (knnMethodContext != null && trainingVectors != null) {
long minTrainingVectorCount = 1000;

MethodComponentContext encoderContext = (MethodComponentContext) knnMethodContext.getMethodComponentContext()
.getParameters()
.get(METHOD_ENCODER_PARAMETER);

if (encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) {

int code_size = ((Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE));
minTrainingVectorCount = (long) Math.pow(2, code_size);
}

if (trainingVectors < minTrainingVectorCount) {
builder.valid(false).minTrainingVectorCount(minTrainingVectorCount);
return builder.build();
} else {
builder.valid(true);
}
}

return builder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@

import java.util.Set;

import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;

/**
* Flat faiss encoder. Flat encoding means that it does nothing. It needs an encoder, though, because it
* is used in generating the index description.
Expand Down Expand Up @@ -55,10 +52,4 @@ public CompressionLevel calculateCompressionLevel(
) {
return CompressionLevel.x1;
}

@Override
public TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput trainingConfigValidationInput) {
TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();
return builder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE;

/**
* Faiss HNSW method implementation
Expand Down Expand Up @@ -136,34 +134,17 @@ protected Function<TrainingConfigValidationInput, TrainingConfigValidationOutput
return (trainingConfigValidationInput) -> {

KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
Long trainingVectors = trainingConfigValidationInput.getTrainingVectorsCount();

TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();

// validate number of training points should be greater than minimum clustering criteria defined in faiss
if (knnMethodContext != null && trainingVectors != null) {
long minTrainingVectorCount = 1000;

MethodComponentContext encoderContext = (MethodComponentContext) knnMethodContext.getMethodComponentContext()
.getParameters()
.get(METHOD_ENCODER_PARAMETER);

if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_PARAMETER_NLIST)
&& encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) {

int nlist = ((Integer) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_PARAMETER_NLIST));
int code_size = ((Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE));
minTrainingVectorCount = (long) Math.max(nlist, Math.pow(2, code_size));
}

if (trainingVectors < minTrainingVectorCount) {
builder.valid(false).minTrainingVectorCount(minTrainingVectorCount);
return builder.build();
} else {
builder.valid(true);
}
if (isEncoderSpecified(knnMethodContext) == false) {
return builder.build();
}
Encoder encoder = SUPPORTED_ENCODERS.get(getEncoderName(knnMethodContext));
if (encoder == null) {
return builder.build();
}
return builder.build();

return encoder.validateEncoderConfig(trainingConfigValidationInput);
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES_DEFAULT;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES_LIMIT;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE;

/**
* Faiss ivf implementation
Expand Down Expand Up @@ -161,34 +160,17 @@ protected Function<TrainingConfigValidationInput, TrainingConfigValidationOutput
return (trainingConfigValidationInput) -> {

KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
Long trainingVectors = trainingConfigValidationInput.getTrainingVectorsCount();

TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();

// validate number of training points should be greater than minimum clustering criteria defined in faiss
if (knnMethodContext != null && trainingVectors != null) {
long minTrainingVectorCount = 1000;

MethodComponentContext encoderContext = (MethodComponentContext) knnMethodContext.getMethodComponentContext()
.getParameters()
.get(METHOD_ENCODER_PARAMETER);

if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_PARAMETER_NLIST)
&& encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) {

int nlist = ((Integer) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_PARAMETER_NLIST));
int code_size = ((Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE));
minTrainingVectorCount = (long) Math.max(nlist, Math.pow(2, code_size));
}

if (trainingVectors < minTrainingVectorCount) {
builder.valid(false).minTrainingVectorCount(minTrainingVectorCount);
return builder.build();
} else {
builder.valid(true);
}
if (isEncoderSpecified(knnMethodContext) == false) {
return builder.build();
}
return builder.build();
Encoder encoder = SUPPORTED_ENCODERS.get(getEncoderName(knnMethodContext));
if (encoder == null) {
return builder.build();
}

return encoder.validateEncoderConfig(trainingConfigValidationInput);
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.engine.Parameter;
import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;

import java.util.Objects;
import java.util.Set;
Expand Down Expand Up @@ -63,10 +61,4 @@ public CompressionLevel calculateCompressionLevel(
// TODO: Hard code for now
return CompressionLevel.x2;
}

@Override
public TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput trainingConfigValidationInput) {
TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();
return builder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,13 @@
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
import org.opensearch.knn.index.engine.KNNMethodContext;

import java.util.HashMap;
import java.util.Locale;
import java.util.Set;

import static org.opensearch.knn.common.KNNConstants.FAISS_FLAT_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;

/**
* Quantization framework binary encoder,
Expand Down Expand Up @@ -114,26 +110,4 @@ public CompressionLevel calculateCompressionLevel(
// Validation will ensure that only 1 of the supported bit count will be selected.
return CompressionLevel.x8;
}

@Override
public TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput trainingConfigValidationInput) {
KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
KNNMethodConfigContext knnMethodConfigContext = trainingConfigValidationInput.getKnnMethodConfigContext();

TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();

// validate ENCODER_PARAMETER_PQ_M is divisible by vector dimension
if (knnMethodContext != null && knnMethodConfigContext != null) {
if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(ENCODER_PARAMETER_PQ_M)
&& knnMethodConfigContext.getDimension() % (Integer) knnMethodContext.getMethodComponentContext()
.getParameters()
.get(ENCODER_PARAMETER_PQ_M) != 0) {
builder.valid(false);
return builder.build();
} else {
builder.valid(true);
}
}
return builder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.engine.Parameter;
import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;

import java.util.List;
import java.util.Set;
Expand Down Expand Up @@ -63,10 +61,4 @@ public CompressionLevel calculateCompressionLevel(
// Hard coding to 4x for now, given thats all that is supported.
return CompressionLevel.x4;
}

@Override
public TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput trainingConfigValidationInput) {
TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();
return builder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,6 @@ public CompressionLevel calculateCompressionLevel(
) {
return DEFAULT_COMPRESSION;
}

@Override
public TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput trainingConfigValidationInput) {
TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();
return builder.build();
}
};

private final static Map<String, Encoder> ENCODER_MAP = Map.of(ENCODER_NAME, TEST_ENCODER);
Expand Down

0 comments on commit 2cc1523

Please sign in to comment.