From 7933bb55e5689f2be5005fadebb898807b4942fc Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Fri, 24 May 2024 09:00:14 -0700 Subject: [PATCH] Use the Lucene Distance Calculation Function in Script Scoring for doing exact search (#1699) * Use the Lucene Distance Calculation Function in Script Scoring for doing exact search Signed-off-by: Ryan Bogan * Add Changelog entry Signed-off-by: Ryan Bogan * Fix failing test Signed-off-by: Ryan Bogan * fix test Signed-off-by: Ryan Bogan * Fix test bug and remove unnecessary validation Signed-off-by: Ryan Bogan * Remove cosineSimilOptimized Signed-off-by: Ryan Bogan * Revert "Remove cosineSimilOptimized" This reverts commit f872d8389683186c9ff64f6a65fd77f170f4a47d. Signed-off-by: Ryan Bogan --------- Signed-off-by: Ryan Bogan (cherry picked from commit 7a88f40dd084cd4c5d9cc5c1ef3d8b26fd25d422) --- CHANGELOG.md | 1 + .../knn/plugin/script/KNNScoringUtil.java | 32 ++++--------------- .../plugin/script/KNNScoringSpaceTests.java | 4 +-- 3 files changed, 9 insertions(+), 28 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c837b493..d6f9739f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.14...2.x) ### Features +* Use the Lucene Distance Calculation Function in Script Scoring for doing exact search [#1699](https://github.com/opensearch-project/k-NN/pull/1699) ### Enhancements * Add KnnCircuitBreakerException and modify exception message [#1688](https://github.com/opensearch-project/k-NN/pull/1688) * Add stats for radial search [#1684](https://github.com/opensearch-project/k-NN/pull/1684) diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java index 114499100..84e986faa 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java @@ -10,6 +10,7 @@ import java.util.Objects; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.lucene.util.VectorUtil; import org.opensearch.knn.index.KNNVectorScriptDocValues; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; @@ -48,13 +49,7 @@ private static void requireEqualDimension(final float[] queryVector, final float * @return L2 score */ public static float l2Squared(float[] queryVector, float[] inputVector) { - requireEqualDimension(queryVector, inputVector); - float squaredDistance = 0; - for (int i = 0; i < inputVector.length; i++) { - float diff = queryVector[i] - inputVector[i]; - squaredDistance += diff * diff; - } - return squaredDistance; + return VectorUtil.squareDistance(queryVector, inputVector); } private static float[] toFloat(List inputVector, VectorDataType vectorDataType) { @@ -148,20 +143,12 @@ public static float cosineSimilarity(List queryVector, KNNVectorScriptDo */ public static float cosinesimil(float[] queryVector, float[] inputVector) { requireEqualDimension(queryVector, inputVector); - float dotProduct = 0.0f; - float normQueryVector = 0.0f; - float normInputVector = 0.0f; - for (int i = 0; i < queryVector.length; i++) { - dotProduct += queryVector[i] * inputVector[i]; - normQueryVector += queryVector[i] * queryVector[i]; - normInputVector += inputVector[i] * inputVector[i]; - } - float normalizedProduct = normQueryVector * normInputVector; - if (normalizedProduct == 0) { + try { + return VectorUtil.cosine(queryVector, inputVector); + } catch (IllegalArgumentException | AssertionError e) { logger.debug("Invalid vectors for cosine. Returning minimum score to put this result to end"); return 0.0f; } - return (float) (dotProduct / (Math.sqrt(normalizedProduct))); } /** @@ -217,7 +204,6 @@ public static float calculateHammingBit(Long queryLong, Long inputLong) { * @return L1 score */ public static float l1Norm(float[] queryVector, float[] inputVector) { - requireEqualDimension(queryVector, inputVector); float distance = 0; for (int i = 0; i < inputVector.length; i++) { float diff = queryVector[i] - inputVector[i]; @@ -255,7 +241,6 @@ public static float l1Norm(List queryVector, KNNVectorScriptDocValues do * @return L-inf score */ public static float lInfNorm(float[] queryVector, float[] inputVector) { - requireEqualDimension(queryVector, inputVector); float distance = 0; for (int i = 0; i < inputVector.length; i++) { float diff = queryVector[i] - inputVector[i]; @@ -293,12 +278,7 @@ public static float lInfNorm(List queryVector, KNNVectorScriptDocValues * @return dot product score */ public static float innerProduct(float[] queryVector, float[] inputVector) { - requireEqualDimension(queryVector, inputVector); - float distance = 0; - for (int i = 0; i < inputVector.length; i++) { - distance += queryVector[i] * inputVector[i]; - } - return distance; + return VectorUtil.dotProduct(queryVector, inputVector); } /** diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java index 6b40f375c..3cfbe56f1 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java @@ -47,7 +47,7 @@ public void testL2() { public void testCosineSimilarity() { float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f }; - List arrayListQueryObject = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); + List arrayListQueryObject = new ArrayList<>(Arrays.asList(2.0, 4.0, 6.0)); float[] arrayFloat2 = new float[] { 2.0f, 4.0f, 6.0f }; KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); @@ -59,7 +59,7 @@ public void testCosineSimilarity() { ); KNNScoringSpace.CosineSimilarity cosineSimilarity = new KNNScoringSpace.CosineSimilarity(arrayListQueryObject, fieldType); - assertEquals(3F, cosineSimilarity.scoringMethod.apply(arrayFloat2, arrayFloat), 0.1F); + assertEquals(2F, cosineSimilarity.scoringMethod.apply(arrayFloat2, arrayFloat), 0.1F); // invalid zero vector final List queryZeroVector = List.of(0.0f, 0.0f, 0.0f);