Added more detailed error messages for KNN model training #2378

Merged · 3 commits · Jan 24, 2025
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Introduced a writing layer in native engines which relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Allow method parameter override for training based indices (#2290)[https://github.com/opensearch-project/k-NN/pull/2290]
- Optimizes lucene query execution to prevent unnecessary rewrites (#2305)[https://github.com/opensearch-project/k-NN/pull/2305]
- Added more detailed error messages for KNN model training (#2378)[https://github.com/opensearch-project/k-NN/pull/2378]
- Add check to directly use ANN Search when filters match all docs. (#2320)[https://github.com/opensearch-project/k-NN/pull/2320]
- Use one formula to calculate cosine similarity (#2357)[https://github.com/opensearch-project/k-NN/pull/2357]
- Remove DocsWithFieldSet reference from NativeEngineFieldVectorsWriter (#2408)[https://github.com/opensearch-project/k-NN/pull/2408]
@@ -225,7 +225,7 @@ public void testIVFSQFP16_onUpgradeWhenIndexedAndQueried_thenSucceed() throws Ex

// Add training data
createBasicKnnIndex(TRAIN_INDEX, TRAIN_TEST_FIELD, DIMENSION);
- int trainingDataCount = 200;
+ int trainingDataCount = 1100;
bulkIngestRandomVectors(TRAIN_INDEX, TRAIN_TEST_FIELD, trainingDataCount, DIMENSION);

XContentBuilder builder = XContentFactory.jsonBuilder()
@@ -278,7 +278,7 @@ public void testIVFSQFP16_onUpgradeWhenClipToFp16isTrueAndIndexedWithOutOfFP16Ra

// Add training data
createBasicKnnIndex(TRAIN_INDEX, TRAIN_TEST_FIELD, dimension);
- int trainingDataCount = 200;
+ int trainingDataCount = 1100;
bulkIngestRandomVectors(TRAIN_INDEX, TRAIN_TEST_FIELD, trainingDataCount, dimension);

XContentBuilder builder = XContentFactory.jsonBuilder()
@@ -22,6 +22,12 @@
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE;

/**
* Abstract class for KNN methods. This class provides the common functionality for all KNN methods.
@@ -108,6 +114,55 @@ protected PerDimensionProcessor doGetPerDimensionProcessor(
return PerDimensionProcessor.NOOP_PROCESSOR;
}

protected Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> doGetTrainingConfigValidationSetup() {

Reviewer (Member): This is in the right direction, but ideally we want this to be handled per algo type. For instance, PQ validation should happen in https://github.com/opensearch-project/k-NN/blob/main/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissPQEncoder.java, IVF validation should happen in https://github.com/opensearch-project/k-NN/blob/main/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java, etc.

If you think this change wouldn't be possible to make in time for code freeze (Monday), I think we shouldn't block this PR. However, this last refactor should be taken up as a follow-up.

Author (Contributor): Got it, will merge this PR and create a separate PR for the follow-up. Thanks.

return (trainingConfigValidationInput) -> {

KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
KNNMethodConfigContext knnMethodConfigContext = trainingConfigValidationInput.getKnnMethodConfigContext();
Long trainingVectors = trainingConfigValidationInput.getTrainingVectorsCount();

TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();

// validate that the vector dimension is divisible by ENCODER_PARAMETER_PQ_M
if (knnMethodContext != null && knnMethodConfigContext != null) {
if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(ENCODER_PARAMETER_PQ_M)
&& knnMethodConfigContext.getDimension() % (Integer) knnMethodContext.getMethodComponentContext()
.getParameters()
.get(ENCODER_PARAMETER_PQ_M) != 0) {
builder.valid(false);
return builder.build();
} else {
builder.valid(true);
}
}

// validate that the number of training vectors meets the minimum clustering criteria defined in faiss
if (knnMethodContext != null && trainingVectors != null) {
long minTrainingVectorCount = 1000;

MethodComponentContext encoderContext = (MethodComponentContext) knnMethodContext.getMethodComponentContext()
.getParameters()
.get(METHOD_ENCODER_PARAMETER);

// null-check the encoder so a method definition without an explicit encoder cannot NPE here
if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_PARAMETER_NLIST)
&& encoderContext != null
&& encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) {

int nlist = (Integer) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_PARAMETER_NLIST);
int codeSize = (Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE);
minTrainingVectorCount = (long) Math.max(nlist, Math.pow(2, codeSize));
}

if (trainingVectors < minTrainingVectorCount) {
builder.valid(false).minTrainingVectorCount(minTrainingVectorCount);
return builder.build();
} else {
builder.valid(true);
}
}
return builder.build();
};
}
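The two rules this validation function enforces can be exercised in isolation. The sketch below is a hypothetical standalone re-implementation outside the k-NN codebase (class and method names are invented, numbers are illustrative):

```java
// Hypothetical re-implementation of the two training checks above; not the actual k-NN classes.
public class TrainingValidationSketch {

    // PQ splits each vector into m sub-vectors, so the dimension
    // must be evenly divisible by ENCODER_PARAMETER_PQ_M.
    static boolean isPqMValid(int dimension, int pqM) {
        return dimension % pqM == 0;
    }

    // Faiss needs at least max(nlist, 2^codeSize) training vectors to fit
    // the IVF centroids and the PQ codebooks.
    static long minTrainingVectors(int nlist, int codeSize) {
        return (long) Math.max(nlist, Math.pow(2, codeSize));
    }

    public static void main(String[] args) {
        System.out.println(isPqMValid(128, 8));          // true: 128 / 8 = 16 sub-vectors
        System.out.println(isPqMValid(100, 8));          // false: 100 % 8 != 0
        System.out.println(minTrainingVectors(4, 8));    // 256: the 2^8 codebook dominates
        System.out.println(minTrainingVectors(1024, 8)); // 1024: nlist dominates
    }
}
```

This also explains why the tests in this PR raise `trainingDataCount` to 1100: with a default `nlist` of 1024 or a `code_size` of 8 (2^8 = 256 centroids), anything above 1024 satisfies both bounds.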

protected VectorTransformer getVectorTransformer(SpaceType spaceType) {
return VectorTransformerFactory.NOOP_VECTOR_TRANSFORMER;
}
@@ -131,6 +186,7 @@ public KNNLibraryIndexingContext getKNNLibraryIndexingContext(
.perDimensionValidator(doGetPerDimensionValidator(knnMethodContext, knnMethodConfigContext))
.perDimensionProcessor(doGetPerDimensionProcessor(knnMethodContext, knnMethodConfigContext))
.vectorTransformer(getVectorTransformer(knnMethodContext.getSpaceType()))
.trainingConfigValidationSetup(doGetTrainingConfigValidationSetup())
.build();
}

@@ -12,6 +12,7 @@
import org.opensearch.knn.index.mapper.VectorValidator;

import java.util.Map;
import java.util.function.Function;

/**
* Context a library gives to build one of its indices
@@ -49,6 +50,12 @@ public interface KNNLibraryIndexingContext {
*/
PerDimensionProcessor getPerDimensionProcessor();

/**
* @return the function used to validate training model parameters
*/
Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> getTrainingConfigValidationSetup();

/**
* Get the vector transformer that will be used to transform the vector before indexing.
* This will be applied at vector level once entire vector is parsed and validated.
@@ -14,6 +14,7 @@

import java.util.Collections;
import java.util.Map;
import java.util.function.Function;

/**
* Simple implementation of {@link KNNLibraryIndexingContext}
@@ -29,6 +30,7 @@ public class KNNLibraryIndexingContextImpl implements KNNLibraryIndexingContext
private Map<String, Object> parameters = Collections.emptyMap();
@Builder.Default
private QuantizationConfig quantizationConfig = QuantizationConfig.EMPTY;
private Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> trainingConfigValidationSetup;

@Override
public Map<String, Object> getLibraryParameters() {
@@ -59,4 +61,9 @@ public PerDimensionValidator getPerDimensionValidator() {
public PerDimensionProcessor getPerDimensionProcessor() {
return perDimensionProcessor;
}

@Override
public Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> getTrainingConfigValidationSetup() {
return trainingConfigValidationSetup;
}
}
@@ -0,0 +1,26 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.engine;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.Setter;

/**
* This object provides the input of the validation checks for training model inputs.
* The values in this object need to be dynamically set and calling code needs to handle
* the possibility that the values have not been set.
*/
@Setter
@Getter
@Builder
@AllArgsConstructor
public class TrainingConfigValidationInput {
private Long trainingVectorsCount;
private KNNMethodContext knnMethodContext;
private KNNMethodConfigContext knnMethodConfigContext;
}
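Since Lombok generates the builder at compile time, it can help to see what the generated calling pattern looks like. The hand-written sketch below mirrors what `@Builder` produces for a class like this, with simplified field types (`String` stands in for the real context objects); it also illustrates the javadoc's warning that callers must handle values that were never set:

```java
// Hand-written approximation of the Lombok-generated builder; field types simplified.
public class InputSketch {
    private final Long trainingVectorsCount;
    private final String knnMethodContext;

    private InputSketch(Long trainingVectorsCount, String knnMethodContext) {
        this.trainingVectorsCount = trainingVectorsCount;
        this.knnMethodContext = knnMethodContext;
    }

    public Long getTrainingVectorsCount() { return trainingVectorsCount; }
    public String getKnnMethodContext() { return knnMethodContext; }

    public static Builder builder() { return new Builder(); }

    public static class Builder {
        private Long trainingVectorsCount;
        private String knnMethodContext;

        public Builder trainingVectorsCount(Long c) { this.trainingVectorsCount = c; return this; }
        public Builder knnMethodContext(String m) { this.knnMethodContext = m; return this; }
        public InputSketch build() { return new InputSketch(trainingVectorsCount, knnMethodContext); }
    }

    public static void main(String[] args) {
        // Only one field is set; the other stays null, so callers must null-check.
        InputSketch in = InputSketch.builder().trainingVectorsCount(1100L).build();
        System.out.println(in.getTrainingVectorsCount() + " " + in.getKnnMethodContext());
    }
}
```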
@@ -0,0 +1,25 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.engine;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.Setter;

/**
* This object provides the output of the validation checks for training model inputs.
* The values in this object need to be dynamically set and calling code needs to handle
* the possibility that the values have not been set.
*/
@Setter
@Getter
@Builder
@AllArgsConstructor
public class TrainingConfigValidationOutput {
private boolean valid;
private long minTrainingVectorCount;
}
@@ -23,12 +23,18 @@
import org.opensearch.common.ValidationException;
import org.opensearch.common.inject.Inject;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNLibraryIndexingContext;
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportRequestOptions;
import org.opensearch.transport.TransportService;

import java.util.Map;
import java.util.function.Function;

import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES;
import static org.opensearch.search.internal.SearchContext.DEFAULT_TERMINATE_AFTER;
@@ -134,6 +140,29 @@ protected void getTrainingIndexSizeInKB(TrainingModelRequest trainingModelReques
trainingVectors = trainingModelRequest.getMaximumVectorCount();
}

KNNMethodContext knnMethodContext = trainingModelRequest.getKnnMethodContext();
KNNMethodConfigContext knnMethodConfigContext = trainingModelRequest.getKnnMethodConfigContext();

KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine()
.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);

Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> validateTrainingConfig = knnLibraryIndexingContext
.getTrainingConfigValidationSetup();

TrainingConfigValidationInput.TrainingConfigValidationInputBuilder inputBuilder = TrainingConfigValidationInput.builder();

TrainingConfigValidationOutput validation = validateTrainingConfig.apply(
inputBuilder.trainingVectorsCount(trainingVectors).knnMethodContext(knnMethodContext).build()
);
if (!validation.isValid()) {
ValidationException exception = new ValidationException();
exception.addValidationError(
String.format("Number of training points should be at least %d", validation.getMinTrainingVectorCount())
);
listener.onFailure(exception);
return;
}

listener.onResponse(
estimateVectorSetSizeInKB(trainingVectors, trainingModelRequest.getDimension(), trainingModelRequest.getVectorDataType())
);
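As a rough sanity check on what `estimateVectorSetSizeInKB` is being asked to compute, the sketch below does the back-of-the-envelope arithmetic. It assumes 4-byte float32 components and no per-document overhead; the real implementation may account for the vector data type and overhead differently, so this is an illustration, not the actual formula:

```java
// Hypothetical size estimate: vectorCount * dimension * 4 bytes, rounded up to KB.
public class VectorSizeSketch {

    static long vectorSetSizeInKB(long vectorCount, int dimension) {
        long bytes = vectorCount * dimension * 4L; // assumes float32 components
        return (bytes + 1023) / 1024;              // round up to a whole KB
    }

    public static void main(String[] args) {
        // 1100 training vectors of dimension 128: 1100 * 128 * 4 = 563,200 bytes = 550 KB
        System.out.println(vectorSetSizeInKB(1100, 128));
    }
}
```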
@@ -30,10 +30,14 @@
import org.opensearch.knn.index.engine.EngineResolver;
import org.opensearch.knn.index.util.IndexUtil;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.KNNLibraryIndexingContext;
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.indices.ModelDao;

import java.io.IOException;
import java.util.function.Function;

/**
* Request to train and serialize a model
@@ -283,6 +287,21 @@ public ActionRequestValidationException validate() {
exception.addValidationError("Description exceeds limit of " + KNNConstants.MAX_MODEL_DESCRIPTION_LENGTH + " characters");
}

KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine()
.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);
Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> validateTrainingConfig = knnLibraryIndexingContext
.getTrainingConfigValidationSetup();
TrainingConfigValidationInput.TrainingConfigValidationInputBuilder inputBuilder = TrainingConfigValidationInput.builder();
TrainingConfigValidationOutput validation = validateTrainingConfig.apply(
inputBuilder.knnMethodConfigContext(knnMethodConfigContext).knnMethodContext(knnMethodContext).build()
);

// Check that the vector dimension is divisible by ENCODER_PARAMETER_PQ_M
if (!validation.isValid()) {
exception = exception == null ? new ActionRequestValidationException() : exception;
exception.addValidationError("Training request vector dimension is not divisible by ENCODER_PARAMETER_PQ_M");
}

// Validate training index exists
IndexMetadata indexMetadata = clusterService.state().metadata().index(trainingIndex);
if (indexMetadata == null) {
@@ -203,9 +203,7 @@ public void run() {
} catch (Exception e) {
logger.error("Failed to run training job for model \"" + modelId + "\": ", e);
modelMetadata.setState(ModelState.FAILED);
- modelMetadata.setError(
-     "Failed to execute training. May be caused by an invalid method definition or " + "not enough memory to perform training."
- );
+ modelMetadata.setError("Failed to execute training. " + e.getMessage());

KNNCounter.TRAINING_ERRORS.increment();

16 changes: 8 additions & 8 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
@@ -304,7 +304,7 @@ public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHN

// training data needs to be at least equal to the number of centroids for PQ
// which is 2^8 = 256. 8 because that's the only valid code_size for HNSWPQ
- int trainingDataCount = 256;
+ int trainingDataCount = 1100;

SpaceType spaceType = SpaceType.L2;

@@ -468,7 +468,7 @@ public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() {

// training data needs to be at least equal to the number of centroids for PQ
// which is 2^8 = 256. 8 because that's the only valid code_size for HNSWPQ
- int trainingDataCount = 256;
+ int trainingDataCount = 1100;

SpaceType spaceType = SpaceType.L2;

@@ -736,7 +736,7 @@ public void testIVFSQFP16_whenIndexedAndQueried_thenSucceed() {

// Add training data
createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
- int trainingDataCount = 200;
+ int trainingDataCount = 1100;
bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);

XContentBuilder builder = XContentFactory.jsonBuilder()
@@ -960,7 +960,7 @@ public void testIVFSQFP16_whenIndexedWithOutOfFP16Range_thenThrowException() {

// Add training data
createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
- int trainingDataCount = 200;
+ int trainingDataCount = 1100;
bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);

XContentBuilder builder = XContentFactory.jsonBuilder()
@@ -1064,7 +1064,7 @@ public void testIVFSQFP16_whenClipToFp16isTrueAndIndexedWithOutOfFP16Range_thenS

// Add training data
createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
- int trainingDataCount = 200;
+ int trainingDataCount = 1100;
bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);

XContentBuilder builder = XContentFactory.jsonBuilder()
@@ -1144,7 +1144,7 @@ public void testEndToEnd_whenMethodIsHNSWPQAndHyperParametersNotSet_thenSucceed(

// training data needs to be at least equal to the number of centroids for PQ
// which is 2^8 = 256. 8 because that's the only valid code_size for HNSWPQ
- int trainingDataCount = 256;
+ int trainingDataCount = 1100;

SpaceType spaceType = SpaceType.L2;

@@ -1412,7 +1412,7 @@ public void testKNNQuery_withModelDifferentCombination_thenSuccess() throws Exce

// Add training data
createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
- int trainingDataCount = 200;
+ int trainingDataCount = 1100;
bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);

// Call train API - IVF with nlists = 1 is brute force, but will require training
@@ -1767,7 +1767,7 @@ public void testIVF_whenBinaryFormat_whenIVF_thenSuccess() {

createKnnIndex(trainingIndexName, trainIndexMapping);

- int trainingDataCount = 40;
+ int trainingDataCount = 1100;
bulkIngestRandomBinaryVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);

XContentBuilder trainModelXContentBuilder = XContentFactory.jsonBuilder()
@@ -603,7 +603,7 @@ public void testIVFByteVector_whenIndexedAndQueried_thenSucceed() {
.toString();
createKnnIndex(INDEX_NAME, trainIndexMapping);

- int trainingDataCount = 100;
+ int trainingDataCount = 1100;
bulkIngestRandomByteVectors(INDEX_NAME, FIELD_NAME, trainingDataCount, dimension);

XContentBuilder trainModelXContentBuilder = XContentFactory.jsonBuilder()
@@ -620,7 +620,7 @@ public void testKNNScriptScoreOnModelBasedIndex() throws Exception {
int dimensions = randomIntBetween(2, 10);
String trainMapping = createKnnIndexMapping(TRAIN_FIELD_PARAMETER, dimensions);
createKnnIndex(TRAIN_INDEX_PARAMETER, trainMapping);
- bulkIngestRandomVectors(TRAIN_INDEX_PARAMETER, TRAIN_FIELD_PARAMETER, dimensions * 3, dimensions);
+ bulkIngestRandomVectors(TRAIN_INDEX_PARAMETER, TRAIN_FIELD_PARAMETER, 1100, dimensions);

XContentBuilder methodBuilder = XContentFactory.jsonBuilder()
.startObject()
@@ -52,7 +52,7 @@ public class ModeAndCompressionIT extends KNNRestTestCase {

private static final String TRAINING_INDEX_NAME = "training_index";
private static final String TRAINING_FIELD_NAME = "training_field";
- private static final int TRAINING_VECS = 20;
+ private static final int TRAINING_VECS = 1100;

private static final int DIMENSION = 16;
private static final int NUM_DOCS = 20;