Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update training validation to be handled per algo type #2462

Open
wants to merge 4 commits into
base: 2.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ public void testHNSWSQFP16_onUpgradeWhenClipToFp16isTrueAndIndexedWithOutOfFP16R
List<Integer> efConstructionValues = ImmutableList.of(16, 32, 64, 128);
List<Integer> efSearchValues = ImmutableList.of(16, 32, 64, 128);

int dimension = 2;
int dimension = 128;

// Create an index
/**
Expand Down Expand Up @@ -198,16 +198,35 @@ public void testHNSWSQFP16_onUpgradeWhenClipToFp16isTrueAndIndexedWithOutOfFP16R

createKnnIndex(testIndex, mapping);
assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(testIndex)));
Float[] vector1 = { -65523.76f, 65504.2f };
Float[] vector2 = { -270.85f, 65514.2f };
Float[] vector3 = { -150.9f, 65504.0f };
Float[] vector4 = { -20.89f, 100000000.0f };

Float[] vector1 = new Float[dimension];
Float[] vector2 = new Float[dimension];
Float[] vector3 = new Float[dimension];
Float[] vector4 = new Float[dimension];
float[] queryVector = new float[dimension];
int halfDimension = dimension / 2;

for (int i = 0; i < dimension; i++) {
if (i < halfDimension) {
vector1[i] = -65523.76f;
vector2[i] = -270.85f;
vector3[i] = -150.9f;
vector4[i] = -20.89f;
queryVector[i] = -10.5f;
} else {
vector1[i] = 65504.2f;
vector2[i] = 65514.2f;
vector3[i] = 65504.0f;
vector4[i] = 100000000.0f;
queryVector[i] = 25.48f;
}
}

addKnnDoc(testIndex, "1", TEST_FIELD, vector1);
addKnnDoc(testIndex, "2", TEST_FIELD, vector2);
addKnnDoc(testIndex, "3", TEST_FIELD, vector3);
addKnnDoc(testIndex, "4", TEST_FIELD, vector4);

float[] queryVector = { -10.5f, 25.48f };
int k = 4;
Response searchResponse = searchKNNIndex(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, k), k);
List<KNNResult> results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), TEST_FIELD);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,6 @@
import java.util.Set;
import java.util.function.Function;

import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE;

/**
* Abstract class for KNN methods. This class provides the common functionality for all KNN methods.
* It defines the common attributes and methods that all KNN methods should implement.
Expand Down Expand Up @@ -116,49 +111,7 @@ protected PerDimensionProcessor doGetPerDimensionProcessor(

protected Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> doGetTrainingConfigValidationSetup() {
return (trainingConfigValidationInput) -> {

KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
KNNMethodConfigContext knnMethodConfigContext = trainingConfigValidationInput.getKnnMethodConfigContext();
Long trainingVectors = trainingConfigValidationInput.getTrainingVectorsCount();

TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();

// validate ENCODER_PARAMETER_PQ_M is divisible by vector dimension
if (knnMethodContext != null && knnMethodConfigContext != null) {
if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(ENCODER_PARAMETER_PQ_M)
&& knnMethodConfigContext.getDimension() % (Integer) knnMethodContext.getMethodComponentContext()
.getParameters()
.get(ENCODER_PARAMETER_PQ_M) != 0) {
builder.valid(false);
return builder.build();
} else {
builder.valid(true);
}
}

// validate number of training points should be greater than minimum clustering criteria defined in faiss
if (knnMethodContext != null && trainingVectors != null) {
long minTrainingVectorCount = 1000;

MethodComponentContext encoderContext = (MethodComponentContext) knnMethodContext.getMethodComponentContext()
.getParameters()
.get(METHOD_ENCODER_PARAMETER);

if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_PARAMETER_NLIST)
&& encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) {

int nlist = ((Integer) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_PARAMETER_NLIST));
int code_size = ((Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE));
minTrainingVectorCount = (long) Math.max(nlist, Math.pow(2, code_size));
}

if (trainingVectors < minTrainingVectorCount) {
builder.valid(false).minTrainingVectorCount(minTrainingVectorCount);
return builder.build();
} else {
builder.valid(true);
}
}
return builder.build();
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,4 +182,5 @@ protected void validateCompressionConflicts(CompressionLevel originalCompression
throw validationException;
}
}

}
10 changes: 10 additions & 0 deletions src/main/java/org/opensearch/knn/index/engine/Encoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,14 @@ default String getName() {
* return {@link CompressionLevel#NOT_CONFIGURED}
*/
CompressionLevel calculateCompressionLevel(MethodComponentContext encoderContext, KNNMethodConfigContext knnMethodConfigContext);

/**
* Validates config of encoder
*
* @return Validation output of encoder parameters
*/
default TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput validationInput) {
TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();
return builder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@
@Builder
@AllArgsConstructor
public class TrainingConfigValidationOutput {
private boolean valid;
private long minTrainingVectorCount;
private Boolean valid;
private Long minTrainingVectorCount;
}
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,33 @@ static MethodComponentContext getEncoderMethodComponent(MethodComponentContext m
return (MethodComponentContext) object;
}

protected String getEncoderName(KNNMethodContext knnMethodContext) {
if (isEncoderSpecified(knnMethodContext) == false) {
return null;
}

MethodComponentContext methodComponentContext = getEncoderComponentContext(knnMethodContext);
if (methodComponentContext == null) {
return null;
}

return methodComponentContext.getName();
}

protected MethodComponentContext getEncoderComponentContext(KNNMethodContext knnMethodContext) {
if (isEncoderSpecified(knnMethodContext) == false) {
return null;
}

return (MethodComponentContext) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_ENCODER_PARAMETER);
}

protected boolean isEncoderSpecified(KNNMethodContext knnMethodContext) {
return knnMethodContext != null
&& knnMethodContext.getMethodComponentContext().getParameters() != null
&& knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_ENCODER_PARAMETER);
}

@Override
protected SpaceType convertUserToMethodSpaceType(SpaceType spaceType) {
// While FAISS doesn't directly support cosine similarity, we can leverage the mathematical
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@

import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;

import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
import org.opensearch.knn.index.engine.KNNMethodContext;

/**
* Abstract class for Faiss PQ encoders. This class provides the common logic for product quantization based encoders
Expand Down Expand Up @@ -89,4 +94,50 @@ public CompressionLevel calculateCompressionLevel(
// compression
return CompressionLevel.MAX_COMPRESSION_LEVEL;
}

@Override
public TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput trainingConfigValidationInput) {
KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
KNNMethodConfigContext knnMethodConfigContext = trainingConfigValidationInput.getKnnMethodConfigContext();
Long trainingVectors = trainingConfigValidationInput.getTrainingVectorsCount();

TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();

// validate ENCODER_PARAMETER_PQ_M is divisible by vector dimension
if (knnMethodContext != null && knnMethodConfigContext != null) {
if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(ENCODER_PARAMETER_PQ_M)
&& knnMethodConfigContext.getDimension() % (Integer) knnMethodContext.getMethodComponentContext()
.getParameters()
.get(ENCODER_PARAMETER_PQ_M) != 0) {
builder.valid(false);
return builder.build();
} else {
builder.valid(true);
}
}

// validate number of training points should be greater than minimum clustering criteria defined in faiss
if (knnMethodContext != null && trainingVectors != null) {
long minTrainingVectorCount = 1000;

MethodComponentContext encoderContext = (MethodComponentContext) knnMethodContext.getMethodComponentContext()
.getParameters()
.get(METHOD_ENCODER_PARAMETER);

if (encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) {

int code_size = ((Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE));
minTrainingVectorCount = (long) Math.pow(2, code_size);
}

if (trainingVectors < minTrainingVectorCount) {
builder.valid(false).minTrainingVectorCount(minTrainingVectorCount);
return builder.build();
} else {
builder.valid(true);
}
}

return builder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@
import org.opensearch.knn.index.engine.MethodComponent;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.engine.Parameter;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.FAISS_HNSW_DESCRIPTION;
Expand Down Expand Up @@ -124,4 +128,23 @@ private static Parameter.MethodComponentContextParameter initEncoderParameter()
SUPPORTED_ENCODERS.values().stream().collect(Collectors.toMap(Encoder::getName, Encoder::getMethodComponent))
);
}

@Override
protected Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> doGetTrainingConfigValidationSetup() {
return (trainingConfigValidationInput) -> {

KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();

if (isEncoderSpecified(knnMethodContext) == false) {
return builder.build();
}
Encoder encoder = SUPPORTED_ENCODERS.get(getEncoderName(knnMethodContext));
if (encoder == null) {
return builder.build();
}

return encoder.validateEncoderConfig(trainingConfigValidationInput);
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
import org.opensearch.knn.index.engine.MethodComponent;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.engine.Parameter;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES;
Expand Down Expand Up @@ -150,4 +154,23 @@ private static Parameter.MethodComponentContextParameter initEncoderParameter()
SUPPORTED_ENCODERS.values().stream().collect(Collectors.toMap(Encoder::getName, Encoder::getMethodComponent))
);
}

@Override
protected Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> doGetTrainingConfigValidationSetup() {
return (trainingConfigValidationInput) -> {

KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();

if (isEncoderSpecified(knnMethodContext) == false) {
return builder.build();
}
Encoder encoder = SUPPORTED_ENCODERS.get(getEncoderName(knnMethodContext));
if (encoder == null) {
return builder.build();
}

return encoder.validateEncoderConfig(trainingConfigValidationInput);
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import org.opensearch.knn.index.engine.MethodComponent;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.engine.ResolvedMethodContext;
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.index.mapper.Mode;

Expand Down Expand Up @@ -73,6 +75,9 @@ public ResolvedMethodContext resolveMethod(
encoderMap
);

// Validate encoder parameters
validateEncoderConfig(resolvedKNNMethodContext, knnMethodConfigContext, encoderMap);

// Validate that resolved compression doesnt have any conflicts
validateCompressionConflicts(knnMethodConfigContext.getCompressionLevel(), resolvedCompressionLevel);
knnMethodConfigContext.setCompressionLevel(resolvedCompressionLevel);
Expand Down Expand Up @@ -148,6 +153,32 @@ private void validateConfig(KNNMethodConfigContext knnMethodConfigContext) {
}
}

protected void validateEncoderConfig(
KNNMethodContext resolvedKnnMethodContext,
KNNMethodConfigContext knnMethodConfigContext,
Map<String, Encoder> encoderMap
) {
if (isEncoderSpecified(resolvedKnnMethodContext) == false) {
return;
}
Encoder encoder = encoderMap.get(getEncoderName(resolvedKnnMethodContext));
if (encoder == null) {
return;
}

TrainingConfigValidationInput.TrainingConfigValidationInputBuilder inputBuilder = TrainingConfigValidationInput.builder();

TrainingConfigValidationOutput validationOutput = encoder.validateEncoderConfig(
inputBuilder.knnMethodContext(resolvedKnnMethodContext).knnMethodConfigContext(knnMethodConfigContext).build()
);

if (validationOutput.getValid() != null && !validationOutput.getValid()) {
ValidationException validationException = new ValidationException();
validationException.addValidationError("Training request ENCODER_PARAMETER_PQ_M is not divisible by vector dimensions");
throw validationException;
}
}

private CompressionLevel getDefaultCompressionLevel(KNNMethodConfigContext knnMethodConfigContext) {
if (CompressionLevel.isConfigured(knnMethodConfigContext.getCompressionLevel())) {
return knnMethodConfigContext.getCompressionLevel();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ protected void getTrainingIndexSizeInKB(TrainingModelRequest trainingModelReques
TrainingConfigValidationOutput validation = validateTrainingConfig.apply(
inputBuilder.trainingVectorsCount(trainingVectors).knnMethodContext(knnMethodContext).build()
);
if (!validation.isValid()) {
if (validation.getValid() != null && !validation.getValid()) {
ValidationException exception = new ValidationException();
exception.addValidationError(
String.format("Number of training points should be greater than %d", validation.getMinTrainingVectorCount())
Expand Down
Loading
Loading