Skip to content

Commit

Permalink
Add jni interface to use a binary hnsw index with faiss
Browse files Browse the repository at this point in the history
Signed-off-by: Heemin Kim <[email protected]>
  • Loading branch information
heemin32 committed Jun 13, 2024
1 parent 623b610 commit 6b6d483
Show file tree
Hide file tree
Showing 23 changed files with 726 additions and 12 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,5 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Documentation
### Maintenance
### Refactoring
* Add jni interface to use a binary hnsw index with faiss [#1747](https://github.com/opensearch-project/k-NN/pull/1747)

22 changes: 22 additions & 0 deletions jni/include/commons.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,34 @@ namespace knn_jni {
*/
jlong storeVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong);

/**
* This is a utility function that can be used to store data in native memory. This function will allocate memory for
* the data(rows*columns) with initialCapacity and return the memory address where the data is stored.
* If you are using this function for the first time, use memoryAddress = 0 to ensure that a new memory location is created.
* For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location,
* an exception is thrown.
*
* @param memoryAddress The address of the memory location where data will be stored.
* @param data 2D byte array containing data to be stored in native memory.
* @param initialCapacity The initial capacity of the memory location.
* @return memory address where the data is stored.
*/
jlong storeByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong);

/**
* Free up the memory allocated for the data stored in memory address. This function should be used with the memory
* address returned by {@link JNICommons#storeVectorData(long, float[][], long)}
*
* @param memoryAddress address to be freed.
*/
void freeVectorData(jlong);

/**
* Free up the memory allocated for the data stored in memory address. This function should be used with the memory
* address returned by {@link JNICommons#storeByteVectorData(long, byte[][], long)}
*
* @param memoryAddress address to be freed.
*/
void freeByteVectorData(jlong);
}
}
16 changes: 16 additions & 0 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,21 @@ namespace knn_jni {
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ,
jobject parametersJ);

// Create a binary index with ids and vectors. The configuration is defined by values in the Java map, parametersJ.
// The index is serialized to indexPathJ.
void CreateBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ,
jstring indexPathJ, jobject parametersJ);

// Load an index from indexPathJ into memory.
//
// Return a pointer to the loaded index
jlong LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ);

// Load a binary index from indexPathJ into memory.
//
// Return a pointer to the loaded index
jlong LoadBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ);

// Check if a loaded index requires shared state
bool IsSharedIndexStateRequired(jlong indexPointerJ);

Expand All @@ -58,6 +68,12 @@ namespace knn_jni {
jfloatArray queryVectorJ, jint kJ, jlongArray filterIdsJ,
jint filterIdsTypeJ, jintArray parentIdsJ);

// Execute a query against the binary index located in memory at indexPointerJ along with Filters
//
// Return an array of KNNQueryResults
jobjectArray QueryBinaryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jbyteArray queryVectorJ, jint kJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ);

// Free the index located in memory at indexPointerJ
void Free(jlong indexPointer);

Expand Down
6 changes: 6 additions & 0 deletions jni/include/jni_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ namespace knn_jni {

virtual void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ,
int dim, std::vector<float> *vect ) = 0;
virtual void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ,
int dim, std::vector<uint8_t> *vect ) = 0;

virtual std::vector<int64_t> ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) = 0;

Expand All @@ -79,6 +81,8 @@ namespace knn_jni {
// ------------------------------ MISC HELPERS ------------------------------
virtual int GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectArray array2dJ) = 0;

virtual int GetInnerDimensionOf2dJavaByteArray(JNIEnv *env, jobjectArray array2dJ) = 0;

virtual int GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ) = 0;

virtual int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ) = 0;
Expand Down Expand Up @@ -146,6 +150,7 @@ namespace knn_jni {
std::vector<float> Convert2dJavaObjectArrayToCppFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim);
std::vector<int64_t> ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ);
int GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectArray array2dJ);
int GetInnerDimensionOf2dJavaByteArray(JNIEnv *env, jobjectArray array2dJ);
int GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ);
int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ);
int GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ);
Expand All @@ -168,6 +173,7 @@ namespace knn_jni {
void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val);
void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf);
void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<float> *vect);
void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<uint8_t> *vect);

private:
std::unordered_map<std::string, jclass> cachedClasses;
Expand Down
25 changes: 25 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@
#ifdef __cplusplus
extern "C" {
#endif

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createBinaryIndex
* Signature: ([IJILjava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndex
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createIndex
Expand All @@ -42,6 +51,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex
(JNIEnv *, jclass, jstring);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: loadBinaryIndex
* Signature: (Ljava/lang/String;)J
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndex
(JNIEnv *, jclass, jstring);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: isSharedIndexStateRequired
Expand Down Expand Up @@ -82,6 +99,14 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter
(JNIEnv *, jclass, jlong, jfloatArray, jint, jlongArray, jint, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: queryBinaryIndexWithFilter
* Signature: (J[BI[JI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBinaryIndexWithFilter
(JNIEnv *, jclass, jlong, jbyteArray, jint, jlongArray, jint, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: free
Expand Down
16 changes: 16 additions & 0 deletions jni/include/org_opensearch_knn_jni_JNICommons.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ extern "C" {
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData
(JNIEnv *, jclass, jlong, jobjectArray, jlong);

/*
* Class: org_opensearch_knn_jni_JNICommons
* Method: storeByteVectorData
* Signature: (J[[BJ)J
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeByteVectorData
(JNIEnv *, jclass, jlong, jobjectArray, jlong);

/*
* Class: org_opensearch_knn_jni_JNICommons
* Method: freeVectorData
Expand All @@ -34,6 +42,14 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData
(JNIEnv *, jclass, jlong);

/*
* Class: org_opensearch_knn_jni_JNICommons
* Method: freeByteVectorData
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeByteVectorData
(JNIEnv *, jclass, jlong);

#ifdef __cplusplus
}
#endif
Expand Down
24 changes: 23 additions & 1 deletion jni/src/commons.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,32 @@ jlong knn_jni::commons::storeVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIE
return (jlong) vect;
}

/**
 * Stores a Java 2D byte array into a native std::vector<uint8_t> and returns the vector's address.
 *
 * When memoryAddressJ is 0, a new vector is allocated and initialCapacityJ elements are reserved.
 * Otherwise memoryAddressJ is reinterpreted as a pointer to an existing std::vector<uint8_t>, which
 * is reused as-is (no additional reserve is performed).
 *
 * The inner dimension of dataJ is obtained through jniUtil->GetInnerDimensionOf2dJavaByteArray and the
 * data is written to the vector through jniUtil->Convert2dJavaObjectArrayAndStoreToByteVector.
 *
 * If anything throws while the vector is being prepared or filled, a vector that was allocated by this
 * call is deleted before the exception is propagated: the caller never receives its address, so it
 * could not be freed otherwise. A caller-supplied vector is never deleted here.
 *
 * @param jniUtil          helper used to read the Java array
 * @param env              JNI environment
 * @param memoryAddressJ   0 to allocate a new vector, or the address returned by an earlier call
 * @param dataJ            2D byte array whose contents are stored
 * @param initialCapacityJ capacity reserved when a new vector is allocated
 * @return address of the vector holding the data
 */
jlong knn_jni::commons::storeByteVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ,
                                            jobjectArray dataJ, jlong initialCapacityJ) {
    // Compare the full 64-bit jlong; narrowing it to `long` first would drop the upper half of the
    // address on platforms where `long` is 32 bits wide.
    const bool allocatedHere = (memoryAddressJ == 0);

    std::vector<uint8_t> *vect = allocatedHere
        ? new std::vector<uint8_t>()
        : reinterpret_cast<std::vector<uint8_t>*>(memoryAddressJ);

    try {
        if (allocatedHere) {
            vect->reserve(static_cast<std::vector<uint8_t>::size_type>(initialCapacityJ));
        }
        int dim = jniUtil->GetInnerDimensionOf2dJavaByteArray(env, dataJ);
        jniUtil->Convert2dJavaObjectArrayAndStoreToByteVector(env, dataJ, dim, vect);
    } catch (...) {
        // Only release memory this call owns; an existing vector still belongs to the caller.
        if (allocatedHere) {
            delete vect;
        }
        throw;
    }

    return reinterpret_cast<jlong>(vect);
}

// Deletes the std::vector<float> whose address is carried in memoryAddressJ.
// An address of 0 means there is nothing to free and the call does nothing.
void knn_jni::commons::freeVectorData(jlong memoryAddressJ) {
    if (memoryAddressJ == 0) {
        return;
    }
    delete reinterpret_cast<std::vector<float>*>(memoryAddressJ);
}
#endif //OPENSEARCH_KNN_COMMONS_H

// Deletes the std::vector<uint8_t> whose address is carried in memoryAddressJ.
// An address of 0 means there is nothing to free and the call does nothing.
void knn_jni::commons::freeByteVectorData(jlong memoryAddressJ) {
    if (memoryAddressJ == 0) {
        return;
    }
    delete reinterpret_cast<std::vector<uint8_t>*>(memoryAddressJ);
}
#endif //OPENSEARCH_KNN_COMMONS_H
Loading

0 comments on commit 6b6d483

Please sign in to comment.