Skip to content

Commit

Permalink
Integrates FAISS iterative builds with NativeEngines990KnnVectorsFormat
Browse files Browse the repository at this point in the history
Changes include reusing the same vector buffer in the JNI layer

Signed-off-by: Tejas Shah <[email protected]>
  • Loading branch information
shatejas committed Aug 20, 2024
1 parent 4562bb6 commit 1536f17
Show file tree
Hide file tree
Showing 51 changed files with 1,735 additions and 1,113 deletions.
17 changes: 15 additions & 2 deletions jni/include/commons.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,19 @@ namespace knn_jni {
* For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location
* will throw Exception.
*
* append tells the method to keep appending to the existing vector. Passing the value as false will clear the vector
* without reallocating new memory. This helps with reducing memory frangmentation and overhead of allocating
* and deallocating when the memory address needs to be reused.
*
* CAUTION: The behavior is undefined if the memory address is deallocated and the method is called
*
* @param memoryAddress The address of the memory location where data will be stored.
* @param data 2D float array containing data to be stored in native memory.
* @param initialCapacity The initial capacity of the memory location.
* @param append whether to append or start from index 0 when called subsequently with the same address
* @return memory address of std::vector<float> where the data is stored.
*/
jlong storeVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong);
jlong storeVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong, jboolean);

/**
* This is utility function that can be used to store data in native memory. This function will allocate memory for
Expand All @@ -33,12 +40,18 @@ namespace knn_jni {
* For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location
* will throw Exception.
*
* append tells the method to keep appending to the existing vector. Passing the value as false will clear the vector
* without reallocating new memory. This helps with reducing memory frangmentation and overhead of allocating
* and deallocating when the memory address needs to be reused.
*
* CAUTION: The behavior is undefined if the memory address is deallocated and the method is called
*
* @param memoryAddress The address of the memory location where data will be stored.
* @param data 2D byte array containing data to be stored in native memory.
* @param initialCapacity The initial capacity of the memory location.
* @return memory address of std::vector<uint8_t> where the data is stored.
*/
jlong storeByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong);
jlong storeByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong, jboolean);

/**
* Free up the memory allocated for the data stored in memory address. This function should be used with the memory
Expand Down
6 changes: 3 additions & 3 deletions jni/include/org_opensearch_knn_jni_JNICommons.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@ extern "C" {
/*
* Class: org_opensearch_knn_jni_JNICommons
* Method: storeVectorData
* Signature: (J[[FJJ)
* Signature: (J[[FJJJ)
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData
(JNIEnv *, jclass, jlong, jobjectArray, jlong);
(JNIEnv *, jclass, jlong, jobjectArray, jlong, jboolean);

/*
* Class: org_opensearch_knn_jni_JNICommons
* Method: storeVectorData
* Signature: (J[[FJJ)
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeByteVectorData
(JNIEnv *, jclass, jlong, jobjectArray, jlong);
(JNIEnv *, jclass, jlong, jobjectArray, jlong, jboolean);

/*
* Class: org_opensearch_knn_jni_JNICommons
Expand Down
14 changes: 12 additions & 2 deletions jni/src/commons.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,29 +18,39 @@
#include "commons.h"

jlong knn_jni::commons::storeVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ,
jobjectArray dataJ, jlong initialCapacityJ) {
jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) {
std::vector<float> *vect;
if ((long) memoryAddressJ == 0) {
vect = new std::vector<float>();
vect->reserve((long)initialCapacityJ);
} else {
vect = reinterpret_cast<std::vector<float>*>(memoryAddressJ);
}

if (appendJ == JNI_FALSE) {
vect->clear();
}

int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, dataJ);
jniUtil->Convert2dJavaObjectArrayAndStoreToFloatVector(env, dataJ, dim, vect);

return (jlong) vect;
}

jlong knn_jni::commons::storeByteVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ,
jobjectArray dataJ, jlong initialCapacityJ) {
jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) {
std::vector<uint8_t> *vect;
if ((long) memoryAddressJ == 0) {
vect = new std::vector<uint8_t>();
vect->reserve((long)initialCapacityJ);
} else {
vect = reinterpret_cast<std::vector<uint8_t>*>(memoryAddressJ);
}

if (appendJ == JNI_FALSE) {
vect->clear();
}

int dim = jniUtil->GetInnerDimensionOf2dJavaByteArray(env, dataJ);
jniUtil->Convert2dJavaObjectArrayAndStoreToByteVector(env, dataJ, dim, vect);

Expand Down
42 changes: 22 additions & 20 deletions jni/src/faiss_index_service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,25 +83,30 @@ jlong IndexService::initIndex(
std::unordered_map<std::string, jobject> parameters
) {
// Create index using Faiss factory method
std::unique_ptr<faiss::Index> indexWriter(faissMethods->indexFactory(dim, indexDescription.c_str(), metric));
std::unique_ptr<faiss::Index> index(faissMethods->indexFactory(dim, indexDescription.c_str(), metric));

// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
if(threadCount != 0) {
omp_set_num_threads(threadCount);
}

// Add extra parameters that cant be configured with the index factory
SetExtraParameters<faiss::Index, faiss::IndexIVF, faiss::IndexHNSW>(jniUtil, env, parameters, indexWriter.get());
SetExtraParameters<faiss::Index, faiss::IndexIVF, faiss::IndexHNSW>(jniUtil, env, parameters, index.get());

// Check that the index does not need to be trained
if(!indexWriter->is_trained) {
if(!index->is_trained) {
throw std::runtime_error("Index is not trained");
}

std::unique_ptr<faiss::IndexIDMap> idMap (faissMethods->indexIdMap(indexWriter.get()));
std::unique_ptr<faiss::IndexIDMap> idMap (faissMethods->indexIdMap(index.get()));
//Makes sure the index is deleted when the destructor is called, this cannot be passed in the constructor
idMap->own_fields = true;

allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors);
indexWriter.release();

//Release the ownership so as to make sure not delete the underlying index that is created. The index is needed later
//in insert and write operations
index.release();
return reinterpret_cast<jlong>(idMap.release());
}

Expand All @@ -113,13 +118,13 @@ void IndexService::insertToIndex(
std::vector<int64_t> & ids,
jlong idMapAddress
) {
// Read vectors from memory address (unique ptr since we want to remove from memory after use)
// Read vectors from memory address
std::vector<float> * inputVectors = reinterpret_cast<std::vector<float>*>(vectorsAddress);

// The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value
int numVectors = (int) (inputVectors->size() / (uint64_t) dim);
if(numVectors == 0) {
return;
throw std::runtime_error("Number of vectors cannot be 0");
}

if (numIds != numVectors) {
Expand Down Expand Up @@ -147,11 +152,8 @@ void IndexService::writeIndex(
// Write the index to disk
faissMethods->writeIndex(idMap.get(), indexPath.c_str());
} catch(std::exception &e) {
delete idMap->index;
throw std::runtime_error("Failed to write index to disk");
}
// Free the memory used by the index
delete idMap->index;
}

BinaryIndexService::BinaryIndexService(std::unique_ptr<FaissMethods> faissMethods) : IndexService(std::move(faissMethods)) {}
Expand All @@ -175,25 +177,29 @@ jlong BinaryIndexService::initIndex(
std::unordered_map<std::string, jobject> parameters
) {
// Create index using Faiss factory method
std::unique_ptr<faiss::IndexBinary> indexWriter(faissMethods->indexBinaryFactory(dim, indexDescription.c_str()));

std::unique_ptr<faiss::IndexBinary> index(faissMethods->indexBinaryFactory(dim, indexDescription.c_str()));
// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
if(threadCount != 0) {
omp_set_num_threads(threadCount);
}

// Add extra parameters that cant be configured with the index factory
SetExtraParameters<faiss::IndexBinary, faiss::IndexBinaryIVF, faiss::IndexBinaryHNSW>(jniUtil, env, parameters, indexWriter.get());
SetExtraParameters<faiss::IndexBinary, faiss::IndexBinaryIVF, faiss::IndexBinaryHNSW>(jniUtil, env, parameters, index.get());

// Check that the index does not need to be trained
if(!indexWriter->is_trained) {
if(!index->is_trained) {
throw std::runtime_error("Index is not trained");
}

std::unique_ptr<faiss::IndexBinaryIDMap> idMap(faissMethods->indexBinaryIdMap(indexWriter.get()));
std::unique_ptr<faiss::IndexBinaryIDMap> idMap(faissMethods->indexBinaryIdMap(index.get()));
//Makes sure the index is deleted when the destructor is called
idMap->own_fields = true;

allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors);
indexWriter.release();

//Release the ownership so as to make sure not delete the underlying index that is created. The index is needed later
//in insert and write operations
index.release();
return reinterpret_cast<jlong>(idMap.release());
}

Expand Down Expand Up @@ -240,12 +246,8 @@ void BinaryIndexService::writeIndex(
// Write the index to disk
faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str());
} catch(std::exception &e) {
delete idMap->index;
throw std::runtime_error("Failed to write index to disk");
}

// Free the memory used by the index
delete idMap->index;
}

} // namespace faiss_wrapper
Expand Down
9 changes: 4 additions & 5 deletions jni/src/nmslib_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,14 @@ void knn_jni::nmslib_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, J
}
jniUtil->ReleaseIntArrayElements(env, idsJ, idsCpp, JNI_ABORT);

// Releasing the vectorsAddressJ memory as that is not required once we have created the index.
std::unique_ptr<similarity::Index<float>> index;
index.reset(similarity::MethodFactoryRegistry<float>::Instance().CreateMethod(false, "hnsw", spaceTypeCpp, *(space), dataset));
index->CreateIndex(similarity::AnyParams(indexParameters));
// Releasing the vectorsAddressJ memory as that is not required once we have created the index.
// This is not the ideal approach, please refer this gh issue for long term solution:
// https://github.com/opensearch-project/k-NN/issues/1600
//commons::freeVectorData(vectorsAddressJ);
delete inputVectors;

std::unique_ptr<similarity::Index<float>> index;
index.reset(similarity::MethodFactoryRegistry<float>::Instance().CreateMethod(false, "hnsw", spaceTypeCpp, *(space), dataset));
index->CreateIndex(similarity::AnyParams(indexParameters));
index->SaveIndex(indexPathCpp);

for (auto & it : dataset) {
Expand Down
2 changes: 0 additions & 2 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JN
std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods());
knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods));
knn_jni::faiss_wrapper::InsertToIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexAddress, threadCount, &indexService);
delete reinterpret_cast<std::vector<float>*>(vectorsAddressJ);
} catch (...) {
// NOTE: ADDING DELETE STATEMENT HERE CAUSES A CRASH!
jniUtil.CatchCppExceptionAndThrowJava(env);
Expand All @@ -90,7 +89,6 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToBinaryIn
std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods());
knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods));
knn_jni::faiss_wrapper::InsertToIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexAddress, threadCount, &binaryIndexService);
delete reinterpret_cast<std::vector<uint8_t>*>(vectorsAddressJ);
} catch (...) {
// NOTE: ADDING DELETE STATEMENT HERE CAUSES A CRASH!
jniUtil.CatchCppExceptionAndThrowJava(env);
Expand Down
8 changes: 4 additions & 4 deletions jni/src/org_opensearch_knn_jni_JNICommons.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,23 +38,23 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) {


JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData(JNIEnv * env, jclass cls,
jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ)
jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ)

{
try {
return knn_jni::commons::storeVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ);
return knn_jni::commons::storeVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ, appendJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
return (long)memoryAddressJ;
}

JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeByteVectorData(JNIEnv * env, jclass cls,
jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ)
jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ)

{
try {
return knn_jni::commons::storeByteVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ);
return knn_jni::commons::storeByteVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ, appendJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
Expand Down
Loading

0 comments on commit 1536f17

Please sign in to comment.