Skip to content

Commit

Permalink
Add integration test
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Bogan <[email protected]>
  • Loading branch information
ryanbogan committed Apr 25, 2024
1 parent 613743a commit 708ef61
Showing 1 changed file with 44 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,12 @@
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.containsString;
import static org.opensearch.knn.common.KNNConstants.*;

public class KNNScriptScoringIT extends KNNRestTestCase {

private static final String TEST_MODEL = "test-model";

public void testKNNL2ScriptScore() throws Exception {
testKNNScriptScore(SpaceType.L2);
}
Expand Down Expand Up @@ -550,6 +553,47 @@ public void testKNNScriptScoreWithRequestCacheEnabled() throws Exception {
assertEquals(1, secondQueryCacheMap.get("hit_count"));
}

public void testKNNScriptScoreOnModelBasedIndex() throws Exception {
int dimensions = randomIntBetween(2, 10);
String modelName = TEST_MODEL;
String trainMapping = createKnnIndexMapping(TRAIN_FIELD_PARAMETER, dimensions);
createKnnIndex(TRAIN_INDEX_PARAMETER, trainMapping);
bulkIngestRandomVectors(TRAIN_INDEX_PARAMETER, TRAIN_FIELD_PARAMETER, dimensions * 3, dimensions);

XContentBuilder methodBuilder = XContentFactory.jsonBuilder()
.startObject()
.field(NAME, METHOD_IVF)
.field(KNN_ENGINE, FAISS_NAME)
.startObject(PARAMETERS)
.field(METHOD_PARAMETER_NLIST, 4)
.field(METHOD_PARAMETER_NPROBES, 2)
.endObject()
.endObject();
Map<String, Object> method = xContentBuilderToMap(methodBuilder);

trainModel(modelName, TRAIN_INDEX_PARAMETER, TRAIN_FIELD_PARAMETER, dimensions, method, "test model for script score");
assertTrainingSucceeds(modelName, 30, 1000);

String testMapping = XContentFactory.jsonBuilder()
.startObject()
.startObject(PROPERTIES_FIELD)
.startObject(FIELD_NAME)
.field(TYPE, TYPE_KNN_VECTOR)
.field(MODEL_ID, modelName)
.endObject()
.endObject()
.endObject()
.toString();

for (SpaceType spaceType : SpaceType.values()) {
if (spaceType != SpaceType.HAMMING_BIT) {
final float[] queryVector = randomVector(dimensions);
final BiFunction<float[], float[], Float> scoreFunction = getScoreFunction(spaceType, queryVector);
createIndexAndAssertScriptScore(testMapping, spaceType, scoreFunction, dimensions, queryVector);
}
}
}

private List<String> createMappers(int dimensions) throws Exception {
return List.of(
createKnnIndexMapping(FIELD_NAME, dimensions),
Expand Down

0 comments on commit 708ef61

Please sign in to comment.