Skip to content

Commit

Permalink
Add validation check for training parameters in engine method abstrac…
Browse files Browse the repository at this point in the history
…tion

Signed-off-by: AnnTian Shao <[email protected]>
  • Loading branch information
AnnTian Shao committed Jan 24, 2025
1 parent 585b373 commit e9d4807
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE;

/**
* Abstract class for KNN methods. This class provides the common functionality for all KNN methods.
Expand Down Expand Up @@ -108,6 +114,55 @@ protected PerDimensionProcessor doGetPerDimensionProcessor(
return PerDimensionProcessor.NOOP_PROCESSOR;
}

protected Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> doGetTrainingConfigValidationSetup() {
return (trainingConfigValidationInput) -> {

KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
KNNMethodConfigContext knnMethodConfigContext = trainingConfigValidationInput.getKnnMethodConfigContext();
Long trainingVectors = trainingConfigValidationInput.getTrainingVectorsCount();

TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();

// validate ENCODER_PARAMETER_PQ_M is divisible by vector dimension
if (knnMethodContext != null && knnMethodConfigContext != null) {
if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(ENCODER_PARAMETER_PQ_M)
&& knnMethodConfigContext.getDimension() % (Integer) knnMethodContext.getMethodComponentContext()
.getParameters()
.get(ENCODER_PARAMETER_PQ_M) != 0) {
builder.valid(false);
return builder.build();
} else {
builder.valid(true);
}
}

// validate number of training points should be greater than minimum clustering criteria defined in faiss
if (knnMethodContext != null && trainingVectors != null) {
long minTrainingVectorCount = 1000;

MethodComponentContext encoderContext = (MethodComponentContext) knnMethodContext.getMethodComponentContext()
.getParameters()
.get(METHOD_ENCODER_PARAMETER);

if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_PARAMETER_NLIST)
&& encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) {

int nlist = ((Integer) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_PARAMETER_NLIST));
int code_size = ((Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE));
minTrainingVectorCount = (long) Math.max(nlist, Math.pow(2, code_size));
}

if (trainingVectors < minTrainingVectorCount) {
builder.valid(false).minTrainingVectorCount(minTrainingVectorCount);
return builder.build();
} else {
builder.valid(true);
}
}
return builder.build();
};
}

protected VectorTransformer getVectorTransformer(SpaceType spaceType) {
return VectorTransformerFactory.NOOP_VECTOR_TRANSFORMER;
}
Expand All @@ -131,6 +186,7 @@ public KNNLibraryIndexingContext getKNNLibraryIndexingContext(
.perDimensionValidator(doGetPerDimensionValidator(knnMethodContext, knnMethodConfigContext))
.perDimensionProcessor(doGetPerDimensionProcessor(knnMethodContext, knnMethodConfigContext))
.vectorTransformer(getVectorTransformer(knnMethodContext.getSpaceType()))
.trainingConfigValidationSetup(doGetTrainingConfigValidationSetup())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.knn.index.mapper.VectorValidator;

import java.util.Map;
import java.util.function.Function;

/**
* Context a library gives to build one of its indices
Expand Down Expand Up @@ -49,6 +50,12 @@ public interface KNNLibraryIndexingContext {
*/
PerDimensionProcessor getPerDimensionProcessor();

/**
*
* @return Get function that validates training model parameters
*/
Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> getTrainingConfigValidationSetup();

/**
* Get the vector transformer that will be used to transform the vector before indexing.
* This will be applied at vector level once entire vector is parsed and validated.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import java.util.Collections;
import java.util.Map;
import java.util.function.Function;

/**
* Simple implementation of {@link KNNLibraryIndexingContext}
Expand All @@ -29,6 +30,7 @@ public class KNNLibraryIndexingContextImpl implements KNNLibraryIndexingContext
private Map<String, Object> parameters = Collections.emptyMap();
@Builder.Default
private QuantizationConfig quantizationConfig = QuantizationConfig.EMPTY;
private Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> trainingConfigValidationSetup;

@Override
public Map<String, Object> getLibraryParameters() {
Expand Down Expand Up @@ -59,4 +61,9 @@ public PerDimensionValidator getPerDimensionValidator() {
public PerDimensionProcessor getPerDimensionProcessor() {
return perDimensionProcessor;
}

@Override
public Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> getTrainingConfigValidationSetup() {
return trainingConfigValidationSetup;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.engine;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.Setter;

/**
* This object provides the input of the validation checks for training model inputs.
* The values in this object need to be dynamically set and calling code needs to handle
* the possibility that the values have not been set.
*/
@Setter
@Getter
@Builder
@AllArgsConstructor
public class TrainingConfigValidationInput {
private Long trainingVectorsCount;
private KNNMethodContext knnMethodContext;
private KNNMethodConfigContext knnMethodConfigContext;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.engine;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.Setter;

/**
* This object provides the output of the validation checks for training model inputs.
* The values in this object need to be dynamically set and calling code needs to handle
* the possibility that the values have not been set.
*/
@Setter
@Getter
@Builder
@AllArgsConstructor
public class TrainingConfigValidationOutput {
private boolean valid;
private long minTrainingVectorCount;
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,20 @@
import org.opensearch.common.ValidationException;
import org.opensearch.common.inject.Inject;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.engine.KNNLibraryIndexingContext;
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportRequestOptions;
import org.opensearch.transport.TransportService;

import java.util.Map;
import java.util.function.Function;

import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.search.internal.SearchContext.DEFAULT_TERMINATE_AFTER;

/**
Expand Down Expand Up @@ -138,26 +140,25 @@ protected void getTrainingIndexSizeInKB(TrainingModelRequest trainingModelReques
trainingVectors = trainingModelRequest.getMaximumVectorCount();
}

long minTrainingVectorCount = 1000;
MethodComponentContext encoderContext = (MethodComponentContext) trainingModelRequest.getKnnMethodContext()
.getMethodComponentContext()
.getParameters()
.get(METHOD_ENCODER_PARAMETER);

if (trainingModelRequest.getKnnMethodContext().getMethodComponentContext().getParameters().containsKey(METHOD_PARAMETER_NLIST)
&& encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) {

int nlist = ((Integer) trainingModelRequest.getKnnMethodContext()
.getMethodComponentContext()
.getParameters()
.get(METHOD_PARAMETER_NLIST));
int code_size = ((Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE));
minTrainingVectorCount = (long) Math.max(nlist, Math.pow(2, code_size));
}
KNNMethodContext knnMethodContext = trainingModelRequest.getKnnMethodContext();
KNNMethodConfigContext knnMethodConfigContext = trainingModelRequest.getKnnMethodConfigContext();

KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine()
.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);

Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> validateTrainingConfig = knnLibraryIndexingContext
.getTrainingConfigValidationSetup();

if (trainingVectors < minTrainingVectorCount) {
TrainingConfigValidationInput.TrainingConfigValidationInputBuilder inputBuilder = TrainingConfigValidationInput.builder();

TrainingConfigValidationOutput validation = validateTrainingConfig.apply(
inputBuilder.trainingVectorsCount(trainingVectors).knnMethodContext(knnMethodContext).build()
);
if (!validation.isValid()) {
ValidationException exception = new ValidationException();
exception.addValidationError("Number of training points should be greater than " + minTrainingVectorCount);
exception.addValidationError(
String.format("Number of training points should be greater than %d", validation.getMinTrainingVectorCount())
);
listener.onFailure(exception);
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,14 @@
import org.opensearch.knn.index.engine.EngineResolver;
import org.opensearch.knn.index.util.IndexUtil;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.KNNLibraryIndexingContext;
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.indices.ModelDao;

import java.io.IOException;

import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;
import java.util.function.Function;

/**
* Request to train and serialize a model
Expand Down Expand Up @@ -285,11 +287,17 @@ public ActionRequestValidationException validate() {
exception.addValidationError("Description exceeds limit of " + KNNConstants.MAX_MODEL_DESCRIPTION_LENGTH + " characters");
}

KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine()
.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);
Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> validateTrainingConfig = knnLibraryIndexingContext
.getTrainingConfigValidationSetup();
TrainingConfigValidationInput.TrainingConfigValidationInputBuilder inputBuilder = TrainingConfigValidationInput.builder();
TrainingConfigValidationOutput validation = validateTrainingConfig.apply(
inputBuilder.knnMethodConfigContext(knnMethodConfigContext).knnMethodContext(knnMethodContext).build()
);

// Check if ENCODER_PARAMETER_PQ_M is divisible by vector dimension
if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(ENCODER_PARAMETER_PQ_M)
&& knnMethodConfigContext.getDimension() % (Integer) knnMethodContext.getMethodComponentContext()
.getParameters()
.get(ENCODER_PARAMETER_PQ_M) != 0) {
if (!validation.isValid()) {
exception = exception == null ? new ActionRequestValidationException() : exception;
exception.addValidationError("Training request ENCODER_PARAMETER_PQ_M is not divisible by vector dimensions");
}
Expand Down

0 comments on commit e9d4807

Please sign in to comment.