Skip to content

Commit

Permalink
fixes to encoder and method classes
Browse files Browse the repository at this point in the history
Signed-off-by: AnnTian Shao <[email protected]>
  • Loading branch information
AnnTian Shao committed Jan 30, 2025
1 parent 9887948 commit 06ecaf9
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 120 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -183,29 +183,4 @@ protected void validateCompressionConflicts(CompressionLevel originalCompression
}
}

protected void validateMDivisibleByVectorDimension(
KNNMethodContext resolvedKnnMethodContext,
KNNMethodConfigContext knnMethodConfigContext,
Map<String, Encoder> encoderMap
) {
if (isEncoderSpecified(resolvedKnnMethodContext) == false) {
return;
}
Encoder encoder = encoderMap.get(getEncoderName(resolvedKnnMethodContext));
if (encoder == null) {
return;
}

TrainingConfigValidationInput.TrainingConfigValidationInputBuilder inputBuilder = TrainingConfigValidationInput.builder();

TrainingConfigValidationOutput validationOutput = encoder.validateEncoderConfig(
inputBuilder.knnMethodContext(resolvedKnnMethodContext).knnMethodConfigContext(knnMethodConfigContext).build()
);

if (!validationOutput.isValid()) {
ValidationException validationException = new ValidationException();
validationException.addValidationError("Training request ENCODER_PARAMETER_PQ_M is not divisible by vector dimensions");
throw validationException;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,9 @@

import java.util.Objects;
import java.util.Set;
import java.util.function.Function;

import static org.opensearch.knn.common.KNNConstants.FAISS_SIGNED_BYTE_SQ;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE;
import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX;
import static org.opensearch.knn.index.engine.faiss.FaissFP16Util.isFaissSQClipToFP16RangeEnabled;
import static org.opensearch.knn.index.engine.faiss.FaissFP16Util.isFaissSQfp16;
Expand Down Expand Up @@ -147,40 +144,4 @@ protected SpaceType convertUserToMethodSpaceType(SpaceType spaceType) {
protected VectorTransformer getVectorTransformer(SpaceType spaceType) {
return VectorTransformerFactory.getVectorTransformer(KNNEngine.FAISS, spaceType);
}

@Override
protected Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> doGetTrainingConfigValidationSetup() {
return (trainingConfigValidationInput) -> {

KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
Long trainingVectors = trainingConfigValidationInput.getTrainingVectorsCount();

TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();

// validate number of training points should be greater than minimum clustering criteria defined in faiss
if (knnMethodContext != null && trainingVectors != null) {
long minTrainingVectorCount = 1000;

MethodComponentContext encoderContext = (MethodComponentContext) knnMethodContext.getMethodComponentContext()
.getParameters()
.get(METHOD_ENCODER_PARAMETER);

if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_PARAMETER_NLIST)
&& encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) {

int nlist = ((Integer) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_PARAMETER_NLIST));
int code_size = ((Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE));
minTrainingVectorCount = (long) Math.max(nlist, Math.pow(2, code_size));
}

if (trainingVectors < minTrainingVectorCount) {
builder.valid(false).minTrainingVectorCount(minTrainingVectorCount);
return builder.build();
} else {
builder.valid(true);
}
}
return builder.build();
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@

import java.util.Set;

import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
import org.opensearch.knn.index.engine.KNNMethodContext;

/**
* Flat faiss encoder. Flat encoding means that it does nothing. It needs an encoder, though, because it
Expand Down Expand Up @@ -60,23 +58,7 @@ public CompressionLevel calculateCompressionLevel(

@Override
public TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput trainingConfigValidationInput) {
KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
KNNMethodConfigContext knnMethodConfigContext = trainingConfigValidationInput.getKnnMethodConfigContext();

TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();

// validate ENCODER_PARAMETER_PQ_M is divisible by vector dimension
if (knnMethodContext != null && knnMethodConfigContext != null) {
if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(ENCODER_PARAMETER_PQ_M)
&& knnMethodConfigContext.getDimension() % (Integer) knnMethodContext.getMethodComponentContext()
.getParameters()
.get(ENCODER_PARAMETER_PQ_M) != 0) {
builder.valid(false);
return builder.build();
} else {
builder.valid(true);
}
}
return builder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@
import org.opensearch.knn.index.engine.MethodComponent;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.engine.Parameter;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.FAISS_HNSW_DESCRIPTION;
Expand All @@ -30,6 +34,8 @@
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE;

/**
* Faiss HNSW method implementation
Expand Down Expand Up @@ -124,4 +130,40 @@ private static Parameter.MethodComponentContextParameter initEncoderParameter()
SUPPORTED_ENCODERS.values().stream().collect(Collectors.toMap(Encoder::getName, Encoder::getMethodComponent))
);
}

@Override
protected Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> doGetTrainingConfigValidationSetup() {
return (trainingConfigValidationInput) -> {

KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
Long trainingVectors = trainingConfigValidationInput.getTrainingVectorsCount();

TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();

// validate number of training points should be greater than minimum clustering criteria defined in faiss
if (knnMethodContext != null && trainingVectors != null) {
long minTrainingVectorCount = 1000;

MethodComponentContext encoderContext = (MethodComponentContext) knnMethodContext.getMethodComponentContext()
.getParameters()
.get(METHOD_ENCODER_PARAMETER);

if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_PARAMETER_NLIST)
&& encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) {

int nlist = ((Integer) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_PARAMETER_NLIST));
int code_size = ((Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE));
minTrainingVectorCount = (long) Math.max(nlist, Math.pow(2, code_size));
}

if (trainingVectors < minTrainingVectorCount) {
builder.valid(false).minTrainingVectorCount(minTrainingVectorCount);
return builder.build();
} else {
builder.valid(true);
}
}
return builder.build();
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
import org.opensearch.knn.index.engine.MethodComponent;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.engine.Parameter;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES;
Expand All @@ -33,6 +37,7 @@
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES_DEFAULT;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES_LIMIT;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE;

/**
* Faiss ivf implementation
Expand Down Expand Up @@ -150,4 +155,40 @@ private static Parameter.MethodComponentContextParameter initEncoderParameter()
SUPPORTED_ENCODERS.values().stream().collect(Collectors.toMap(Encoder::getName, Encoder::getMethodComponent))
);
}

@Override
protected Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> doGetTrainingConfigValidationSetup() {
return (trainingConfigValidationInput) -> {

KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
Long trainingVectors = trainingConfigValidationInput.getTrainingVectorsCount();

TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();

// validate number of training points should be greater than minimum clustering criteria defined in faiss
if (knnMethodContext != null && trainingVectors != null) {
long minTrainingVectorCount = 1000;

MethodComponentContext encoderContext = (MethodComponentContext) knnMethodContext.getMethodComponentContext()
.getParameters()
.get(METHOD_ENCODER_PARAMETER);

if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_PARAMETER_NLIST)
&& encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) {

int nlist = ((Integer) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_PARAMETER_NLIST));
int code_size = ((Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE));
minTrainingVectorCount = (long) Math.max(nlist, Math.pow(2, code_size));
}

if (trainingVectors < minTrainingVectorCount) {
builder.valid(false).minTrainingVectorCount(minTrainingVectorCount);
return builder.build();
} else {
builder.valid(true);
}
}
return builder.build();
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import org.opensearch.knn.index.engine.MethodComponent;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.engine.ResolvedMethodContext;
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.index.mapper.Mode;

Expand Down Expand Up @@ -73,8 +75,8 @@ public ResolvedMethodContext resolveMethod(
encoderMap
);

// Validate ENCODER_PARAMETER_PQ_M is divisible by vector dimension
validateMDivisibleByVectorDimension(resolvedKNNMethodContext, knnMethodConfigContext, encoderMap);
// Validate encoder parameters
validateEncoderConfig(resolvedKNNMethodContext, knnMethodConfigContext, encoderMap);

// Validate that resolved compression doesnt have any conflicts
validateCompressionConflicts(knnMethodConfigContext.getCompressionLevel(), resolvedCompressionLevel);
Expand Down Expand Up @@ -151,6 +153,32 @@ private void validateConfig(KNNMethodConfigContext knnMethodConfigContext) {
}
}

protected void validateEncoderConfig(
KNNMethodContext resolvedKnnMethodContext,
KNNMethodConfigContext knnMethodConfigContext,
Map<String, Encoder> encoderMap
) {
if (isEncoderSpecified(resolvedKnnMethodContext) == false) {
return;
}
Encoder encoder = encoderMap.get(getEncoderName(resolvedKnnMethodContext));
if (encoder == null) {
return;
}

TrainingConfigValidationInput.TrainingConfigValidationInputBuilder inputBuilder = TrainingConfigValidationInput.builder();

TrainingConfigValidationOutput validationOutput = encoder.validateEncoderConfig(
inputBuilder.knnMethodContext(resolvedKnnMethodContext).knnMethodConfigContext(knnMethodConfigContext).build()
);

if (!validationOutput.isValid()) {
ValidationException validationException = new ValidationException();
validationException.addValidationError("Training request ENCODER_PARAMETER_PQ_M is not divisible by vector dimensions");
throw validationException;
}
}

private CompressionLevel getDefaultCompressionLevel(KNNMethodConfigContext knnMethodConfigContext) {
if (CompressionLevel.isConfigured(knnMethodConfigContext.getCompressionLevel())) {
return knnMethodConfigContext.getCompressionLevel();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
import org.opensearch.knn.index.engine.KNNMethodContext;

import java.util.Objects;
import java.util.Set;
Expand All @@ -26,7 +25,6 @@
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_TYPES;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;

/**
* Faiss SQ encoder
Expand Down Expand Up @@ -68,23 +66,7 @@ public CompressionLevel calculateCompressionLevel(

@Override
public TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput trainingConfigValidationInput) {
KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
KNNMethodConfigContext knnMethodConfigContext = trainingConfigValidationInput.getKnnMethodConfigContext();

TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();

// validate ENCODER_PARAMETER_PQ_M is divisible by vector dimension
if (knnMethodContext != null && knnMethodConfigContext != null) {
if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(ENCODER_PARAMETER_PQ_M)
&& knnMethodConfigContext.getDimension() % (Integer) knnMethodContext.getMethodComponentContext()
.getParameters()
.get(ENCODER_PARAMETER_PQ_M) != 0) {
builder.valid(false);
return builder.build();
} else {
builder.valid(true);
}
}
return builder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
import org.opensearch.knn.index.engine.KNNMethodContext;

import java.util.List;
import java.util.Set;
Expand All @@ -27,7 +26,6 @@
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_DEFAULT_BITS;
import static org.opensearch.knn.common.KNNConstants.MAXIMUM_CONFIDENCE_INTERVAL;
import static org.opensearch.knn.common.KNNConstants.MINIMUM_CONFIDENCE_INTERVAL;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;

/**
* Lucene scalar quantization encoder
Expand Down Expand Up @@ -68,23 +66,7 @@ public CompressionLevel calculateCompressionLevel(

@Override
public TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput trainingConfigValidationInput) {
KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
KNNMethodConfigContext knnMethodConfigContext = trainingConfigValidationInput.getKnnMethodConfigContext();

TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();

// validate ENCODER_PARAMETER_PQ_M is divisible by vector dimension
if (knnMethodContext != null && knnMethodConfigContext != null) {
if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(ENCODER_PARAMETER_PQ_M)
&& knnMethodConfigContext.getDimension() % (Integer) knnMethodContext.getMethodComponentContext()
.getParameters()
.get(ENCODER_PARAMETER_PQ_M) != 0) {
builder.valid(false);
return builder.build();
} else {
builder.valid(true);
}
}
return builder.build();
}
}

0 comments on commit 06ecaf9

Please sign in to comment.