Skip to content

Commit

Permalink
Add more tests from feedback
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Jun 13, 2024
1 parent a04cd8e commit 62a7dae
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 23 deletions.
4 changes: 2 additions & 2 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,7 @@ jobjectArray knn_jni::faiss_wrapper::RangeSearchWithFilter(knn_jni::JNIUtilInter
auto hnswReader = dynamic_cast<const faiss::IndexHNSW*>(indexReader->index);
if(hnswReader) {
// Query param ef_search supersedes ef_search provided during index setting.
hnswParams.efSearch = getQueryEfSearch(env, jniUtil, methodParams, hnswReader->hnsw.efSearch);
hnswParams.efSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, hnswReader->hnsw.efSearch);
hnswParams.sel = idSelector.get();
if (parentIdsJ != nullptr) {
idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap);
Expand Down Expand Up @@ -670,7 +670,7 @@ jobjectArray knn_jni::faiss_wrapper::RangeSearchWithFilter(knn_jni::JNIUtilInter
auto hnswReader = dynamic_cast<const faiss::IndexHNSW*>(indexReader->index);
if(hnswReader!= nullptr) {
// Query param ef_search supersedes ef_search provided during index setting.
hnswParams.efSearch = getQueryEfSearch(env, jniUtil, methodParams, hnswReader->hnsw.efSearch);
hnswParams.efSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, hnswReader->hnsw.efSearch);
if (parentIdsJ != nullptr) {
idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap);
hnswParams.grp = idGrouper.get();
Expand Down
158 changes: 141 additions & 17 deletions jni/tests/faiss_wrapper_unit_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include "test_util.h"
#include "faiss/IndexHNSW.h"
#include "faiss/IndexIDMap.h"
#include "faiss/index_factory.h"
#include "faiss/IndexIVFPQ.h"

using ::testing::NiceMock;

Expand All @@ -30,14 +32,15 @@ struct MockIndex : faiss::IndexHNSW {
}
};


struct MockIdMap : faiss::IndexIDMap {
mutable idx_t nCalled;
mutable const float *xCalled;
mutable idx_t kCalled;
mutable float *distancesCalled;
mutable idx_t *labelsCalled;
mutable const faiss::SearchParametersHNSW *paramsCalled;
mutable idx_t nCalled{};
mutable const float *xCalled{};
mutable int kCalled{};
mutable float radiusCalled{};
mutable float *distancesCalled{};
mutable idx_t *labelsCalled{};
mutable const faiss::SearchParametersHNSW *paramsCalled{};
mutable faiss::RangeSearchResult *resCalled{};

explicit MockIdMap(MockIndex *index) : faiss::IndexIDMapTemplate<faiss::Index>(index) {
}
Expand All @@ -57,32 +60,65 @@ struct MockIdMap : faiss::IndexIDMap {
paramsCalled = dynamic_cast<const faiss::SearchParametersHNSW *>(params);
}

void range_search(
idx_t n,
const float *x,
float radius,
faiss::RangeSearchResult *res,
const faiss::SearchParameters *params) const override {
nCalled = n;
xCalled = x;
radiusCalled = radius;
resCalled = res;
paramsCalled = dynamic_cast<const faiss::SearchParametersHNSW *>(params);
}

void resetMock() const {
nCalled = 0;
xCalled = nullptr;
kCalled = 0;
radiusCalled = 0.0;
distancesCalled = nullptr;
labelsCalled = nullptr;
resCalled = nullptr;
paramsCalled = nullptr;
}
};

struct QueryIndexHNSWTestInput {
string description;
std::string description;
int k;
int efSearch;
int filterIdType;
bool filterIdsPresent;
bool parentIdsPresent;
};


struct RangeSearchTestInput {
std::string description;
float radius;
int efSearch;
int filterIdType;
bool filterIdsPresent;
bool parentIdsPresent;
};

class FaissWrappeterParametrizedTestFixture : public testing::TestWithParam<QueryIndexHNSWTestInput> {
public:
FaissWrappeterParametrizedTestFixture() : index_(3), id_map_(&index_) {
index_.hnsw.efSearch = 100; // assigning 100 to make sure default of 16 is not used anywhere
};
}

protected:
MockIndex index_;
MockIdMap id_map_;
};

class FaissWrapperParametrizedRangeSearchTestFixture : public testing::TestWithParam<RangeSearchTestInput> {
public:
FaissWrapperParametrizedRangeSearchTestFixture() : index_(3), id_map_(&index_) {
index_.hnsw.efSearch = 100; // assigning 100 to make sure default of 16 is not used anywhere
}

protected:
MockIndex index_;
Expand All @@ -93,14 +129,13 @@ namespace query_index_test {

std::unordered_map<std::string, jobject> methodParams;


TEST_P(FaissWrappeterParametrizedTestFixture, QueryIndexHNSWTests) {
//Given
// Given
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;


QueryIndexHNSWTestInput const &input = GetParam();
std::cout << "Running test: " << input.description << std::endl;
float query[] = {1.2, 2.3, 3.4};

int efSearch = input.efSearch;
Expand Down Expand Up @@ -137,7 +172,7 @@ namespace query_index_test {
reinterpret_cast<jfloatArray>(&query), input.k, reinterpret_cast<jobject>(&methodParams),
reinterpret_cast<jintArray>(parentIdPtr));

//Then
// Then
int actualEfSearch = id_map_.paramsCalled->efSearch;
// Asserting the captured argument
EXPECT_EQ(input.k, id_map_.kCalled);
Expand Down Expand Up @@ -165,11 +200,10 @@ namespace query_index_test {
namespace query_index_with_filter_test {

TEST_P(FaissWrappeterParametrizedTestFixture, QueryIndexWithFilterHNSWTests) {
//Given
// Given
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;


QueryIndexHNSWTestInput const &input = GetParam();
float query[] = {1.2, 2.3, 3.4};

Expand Down Expand Up @@ -218,7 +252,7 @@ namespace query_index_with_filter_test {
input.filterIdType,
reinterpret_cast<jintArray>(parentIdPtr));

//Then
// Then
int actualEfSearch = id_map_.paramsCalled->efSearch;
// Asserting the captured argument
EXPECT_EQ(input.k, id_map_.kCalled);
Expand Down Expand Up @@ -249,3 +283,93 @@ namespace query_index_with_filter_test {
)
);
}

namespace range_search_test {

TEST_P(FaissWrapperParametrizedRangeSearchTestFixture, RangeSearchHNSWTests) {
// Given
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;

RangeSearchTestInput const &input = GetParam();
float query[] = {1.2, 2.3, 3.4};
float radius = input.radius;
int maxResultWindow = 100; // Set your max result window

std::unordered_map<std::string, jobject> methodParams;
int efSearch = input.efSearch;
int expectedEfSearch = 100; // default set in mock
if (efSearch != -1) {
expectedEfSearch = input.efSearch;
methodParams[knn_jni::EF_SEARCH] = reinterpret_cast<jobject>(&efSearch);
}

std::vector<int> *parentIdPtr = nullptr;
if (input.parentIdsPresent) {
std::vector<int> parentId;
parentId.reserve(2);
parentId.push_back(1);
parentId.push_back(2);
parentIdPtr = &parentId;

EXPECT_CALL(mockJNIUtil,
GetJavaIntArrayLength(
jniEnv, reinterpret_cast<jintArray>(parentIdPtr)))
.WillOnce(testing::Return(parentId.size()));

EXPECT_CALL(mockJNIUtil,
GetIntArrayElements(
jniEnv, reinterpret_cast<jintArray>(parentIdPtr), nullptr))
.WillOnce(testing::Return(new int[2]{1, 2}));
}

std::vector<long> filter;
std::vector<long> *filterptr = nullptr;
if (input.filterIdsPresent) {
filter.reserve(2);
filter.push_back(1);
filter.push_back(2);
filterptr = &filter;
}

// When
knn_jni::faiss_wrapper::RangeSearchWithFilter(
&mockJNIUtil, jniEnv,
reinterpret_cast<jlong>(&id_map_),
reinterpret_cast<jfloatArray>(&query), radius, reinterpret_cast<jobject>(&methodParams),
maxResultWindow,
reinterpret_cast<jlongArray>(filterptr),
input.filterIdType,
reinterpret_cast<jintArray>(parentIdPtr));

// Then
int actualEfSearch = id_map_.paramsCalled->efSearch;
// Asserting the captured argument
EXPECT_EQ(expectedEfSearch, actualEfSearch);
if (input.parentIdsPresent) {
faiss::IDGrouper *grouper = id_map_.paramsCalled->grp;
EXPECT_TRUE(grouper != nullptr);
}
if (input.filterIdsPresent) {
faiss::IDSelector *sel = id_map_.paramsCalled->sel;
EXPECT_TRUE(sel != nullptr);
}
id_map_.resetMock();
}

INSTANTIATE_TEST_CASE_P(
RangeSearchHNSWTests,
FaissWrapperParametrizedRangeSearchTestFixture,
::testing::Values(
RangeSearchTestInput{"algoParams present, parent absent, filter absent", 10.0f, 200, 0, false, false},
RangeSearchTestInput{"algoParams present, parent absent, filter absent, filter type 1", 10.0f, 200, 1, false, false},
RangeSearchTestInput{"algoParams absent, parent absent, filter present", 10.0f, -1, 0, true, false},
RangeSearchTestInput{"algoParams absent, parent absent, filter present, filter type 1", 10.0f, -1, 1, true, false},
RangeSearchTestInput{"algoParams present, parent present, filter absent", 10.0f, 200, 0, false, true},
RangeSearchTestInput{"algoParams present, parent present, filter absent, filter type 1", 10.0f, 150, 1, false, true},
RangeSearchTestInput{"algoParams absent, parent present, filter present", 10.0f, -1, 0, true, true},
RangeSearchTestInput{"algoParams absent, parent present, filter present, filter type 1", 10.0f, -1, 1, true, true}
)
);
}

34 changes: 30 additions & 4 deletions src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyFloat;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
Expand Down Expand Up @@ -763,12 +762,29 @@ public void testDoANNSearch_whenRadialIsDefined_thenCallJniRadiusQueryIndex() {
final float radius = 0.5f;
final int maxResults = 1000;
jniServiceMockedStatic.when(
() -> JNIService.radiusQueryIndex(anyLong(), any(), anyFloat(), any(), any(), anyInt(), any(), anyInt(), any())
() -> JNIService.radiusQueryIndex(
anyLong(),
eq(queryVector),
eq(radius),
eq(HNSW_METHOD_PARAMETERS),
any(),
eq(maxResults),
any(),
anyInt(),
any()
)
).thenReturn(getKNNQueryResults());
KNNQuery.Context context = mock(KNNQuery.Context.class);
when(context.getMaxResultWindow()).thenReturn(maxResults);

final KNNQuery query = new KNNQuery(FIELD_NAME, queryVector, INDEX_NAME, null).radius(radius).kNNQueryContext(context);
final KNNQuery query = KNNQuery.builder()
.field(FIELD_NAME)
.queryVector(queryVector)
.radius(radius)
.indexName(INDEX_NAME)
.context(context)
.methodParameters(HNSW_METHOD_PARAMETERS)
.build();
final float boost = (float) randomDoubleBetween(0, 10, true);
final KNNWeight knnWeight = new KNNWeight(query, boost);

Expand Down Expand Up @@ -807,7 +823,17 @@ public void testDoANNSearch_whenRadialIsDefined_thenCallJniRadiusQueryIndex() {
final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext);
assertNotNull(knnScorer);
jniServiceMockedStatic.verify(
() -> JNIService.radiusQueryIndex(anyLong(), any(), anyFloat(), any(), any(), anyInt(), any(), anyInt(), any())
() -> JNIService.radiusQueryIndex(
anyLong(),
eq(queryVector),
eq(radius),
eq(HNSW_METHOD_PARAMETERS),
any(),
eq(maxResults),
any(),
anyInt(),
any()
)
);

final DocIdSetIterator docIdSetIterator = knnScorer.iterator();
Expand Down

0 comments on commit 62a7dae

Please sign in to comment.