From 015b30b4dc0f7d372c9c1613e823b374ff609741 Mon Sep 17 00:00:00 2001 From: Heemin Kim Date: Fri, 9 Feb 2024 14:03:39 -0800 Subject: [PATCH] Use valid data for sqfp16 test Signed-off-by: Heemin Kim --- .../opensearch/knn/jni/JNIServiceTests.java | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index c2470ea47..7673ee463 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -525,7 +525,7 @@ public void testQueryIndex_faiss_sqfp16_valid() { Path tmpFile = createTempFile(); JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + truncateToFp16Range(testData.indexData.vectors), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, sqfp16IndexDescription, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), FAISS_NAME @@ -547,6 +547,23 @@ public void testQueryIndex_faiss_sqfp16_valid() { } } + // If the value is outside of the fp16 range, then convert it to the fp16 minimum or maximum value + private float[][] truncateToFp16Range(final float[][] data) { + float[][] result = new float[data.length][data[0].length]; + for (int i = 0; i < data.length; i++) { + for (int j = 0; j < data[i].length; j++) { + float value = data[i][j]; + if (value < Float.MIN_VALUE || value > Float.MAX_VALUE) { + // If value is outside of the range, set it to the maximum or minimum value + result[i][j] = value < 0 ? -Float.MAX_VALUE : Float.MAX_VALUE; + } else { + result[i][j] = value; + } + } + } + return result; + } + @SneakyThrows public void testTrain_whenConfigurationIsIVFSQFP16_thenSucceed() { long trainPointer = transferVectors(10);