diff --git a/jni/include/commons.h b/jni/include/commons.h index 05367a6939..61f54233cf 100644 --- a/jni/include/commons.h +++ b/jni/include/commons.h @@ -26,6 +26,20 @@ namespace knn_jni { */ jlong storeVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong); + /** + * This is utility function that can be used to store data in native memory. This function will allocate memory for + * the data(rows*columns) with initialCapacity and return the memory address where the data is stored. + * If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created. + * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location + * will throw Exception. + * + * @param memoryAddress The address of the memory location where data will be stored. + * @param data 2D byte array containing data to be stored in native memory. + * @param initialCapacity The initial capacity of the memory location. + * @return memory address where the data is stored. + */ + jlong storeByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong); + /** * Free up the memory allocated for the data stored in memory address. This function should be used with the memory * address returned by {@link JNICommons#storeVectorData(long, float[][], long, long)} @@ -33,5 +47,13 @@ namespace knn_jni { * @param memoryAddress address to be freed. */ void freeVectorData(jlong); + + /** + * Free up the memory allocated for the data stored in memory address. This function should be used with the memory + * address returned by {@link JNICommons#storeByteVectorData(long, byte[][], long, long)} + * + * @param memoryAddress address to be freed. + */ + void freeByteVectorData(jlong); } } diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 958eca8ac3..7f00765d6d 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -28,11 +28,21 @@ namespace knn_jni { jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, jobject parametersJ); + // Create an binary index with ids and vectors. The configuration is defined by values in the Java map, parametersJ. + // The index is serialized to indexPathJ. + void CreateBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, + jstring indexPathJ, jobject parametersJ); + // Load an index from indexPathJ into memory. // // Return a pointer to the loaded index jlong LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ); + // Load a binary index from indexPathJ into memory. + // + // Return a pointer to the loaded index + jlong LoadBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ); + // Check if a loaded index requires shared state bool IsSharedIndexStateRequired(jlong indexPointerJ); @@ -58,6 +68,12 @@ namespace knn_jni { jfloatArray queryVectorJ, jint kJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ); + // Execute a query against the binary index located in memory at indexPointerJ along with Filters + // + // Return an array of KNNQueryResults + jobjectArray QueryBinaryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, + jbyteArray queryVectorJ, jint kJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ); + // Free the index located in memory at indexPointerJ void Free(jlong indexPointer); diff --git a/jni/include/jni_util.h b/jni/include/jni_util.h index b3d55f1c1c..cce94abedb 100644 --- a/jni/include/jni_util.h +++ b/jni/include/jni_util.h @@ -71,6 +71,8 @@ namespace knn_jni { virtual void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect ) = 0; + virtual void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, + int dim, std::vector *vect ) = 0; virtual std::vector ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) = 0; @@ -79,6 +81,8 @@ namespace knn_jni { // ------------------------------ MISC HELPERS ------------------------------ virtual int GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectArray array2dJ) = 0; + virtual int GetInnerDimensionOf2dJavaByteArray(JNIEnv *env, jobjectArray array2dJ) = 0; + virtual int GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ) = 0; virtual int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ) = 0; @@ -146,6 +150,7 @@ namespace knn_jni { std::vector Convert2dJavaObjectArrayToCppFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim); std::vector ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ); int GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectArray array2dJ); + int GetInnerDimensionOf2dJavaByteArray(JNIEnv *env, jobjectArray array2dJ); int GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ); int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ); int GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ); @@ -168,6 +173,7 @@ namespace knn_jni { void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val); void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf); void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect); + void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect); private: std::unordered_map cachedClasses; diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index e16677db70..84c1e9d897 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -18,6 +18,15 @@ #ifdef __cplusplus extern "C" { #endif + +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: createBinaryIndex + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndex + (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject); + /* * Class: org_opensearch_knn_jni_FaissService * Method: createIndex @@ -42,6 +51,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex (JNIEnv *, jclass, jstring); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: loadBinaryIndex + * Signature: (Ljava/lang/String;)J + */ +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndex + (JNIEnv *, jclass, jstring); + /* * Class: org_opensearch_knn_jni_FaissService * Method: isSharedIndexStateRequired @@ -82,6 +99,14 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter (JNIEnv *, jclass, jlong, jfloatArray, jint, jlongArray, jint, jintArray); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: queryBIndexWithFilter + * Signature: (J[BI[JI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult; + */ +JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBinaryIndexWithFilter + (JNIEnv *, jclass, jlong, jbyteArray, jint, jlongArray, jint, jintArray); + /* * Class: org_opensearch_knn_jni_FaissService * Method: free diff --git a/jni/include/org_opensearch_knn_jni_JNICommons.h b/jni/include/org_opensearch_knn_jni_JNICommons.h index d0758d7c8c..89de76520e 100644 --- a/jni/include/org_opensearch_knn_jni_JNICommons.h +++ b/jni/include/org_opensearch_knn_jni_JNICommons.h @@ -26,6 +26,14 @@ extern "C" { JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData (JNIEnv *, jclass, jlong, jobjectArray, jlong); +/* + * Class: org_opensearch_knn_jni_JNICommons + * Method: storeVectorData + * Signature: (J[[FJJ) + */ +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeByteVectorData + (JNIEnv *, jclass, jlong, jobjectArray, jlong); + /* * Class: org_opensearch_knn_jni_JNICommons * Method: freeVectorData @@ -34,6 +42,14 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData (JNIEnv *, jclass, jlong); +/* +* Class: org_opensearch_knn_jni_JNICommons +* Method: freeVectorData +* Signature: (J)V +*/ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeByteVectorData +(JNIEnv *, jclass, jlong); + #ifdef __cplusplus } #endif diff --git a/jni/src/commons.cpp b/jni/src/commons.cpp index 3c03ac49d9..f22344b456 100644 --- a/jni/src/commons.cpp +++ b/jni/src/commons.cpp @@ -32,10 +32,32 @@ jlong knn_jni::commons::storeVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIE return (jlong) vect; } +jlong knn_jni::commons::storeByteVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ, + jobjectArray dataJ, jlong initialCapacityJ) { + std::vector *vect; + if ((long) memoryAddressJ == 0) { + vect = new std::vector(); + vect->reserve((long)initialCapacityJ); + } else { + vect = reinterpret_cast*>(memoryAddressJ); + } + int dim = jniUtil->GetInnerDimensionOf2dJavaByteArray(env, dataJ); + jniUtil->Convert2dJavaObjectArrayAndStoreToByteVector(env, dataJ, dim, vect); + + return (jlong) vect; +} + void knn_jni::commons::freeVectorData(jlong memoryAddressJ) { if (memoryAddressJ != 0) { auto *vect = reinterpret_cast*>(memoryAddressJ); delete vect; } } -#endif //OPENSEARCH_KNN_COMMONS_H \ No newline at end of file + +void knn_jni::commons::freeByteVectorData(jlong memoryAddressJ) { + if (memoryAddressJ != 0) { + auto *vect = reinterpret_cast*>(memoryAddressJ); + delete vect; + } +} +#endif //OPENSEARCH_KNN_COMMONS_H diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 5a0910d9a0..851cb359b1 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -22,6 +22,8 @@ #include "faiss/Index.h" #include "faiss/impl/IDSelector.h" #include "faiss/IndexIVFPQ.h" +#include "faiss/IndexBinaryIVF.h" +#include "faiss/IndexBinaryHNSW.h" #include #include @@ -62,6 +64,8 @@ faiss::MetricType TranslateSpaceToMetric(const std::string& spaceType); // Set additional parameters on faiss index void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, const std::unordered_map& parametersCpp, faiss::Index * index); +void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, + const std::unordered_map& parametersCpp, faiss::IndexBinary * index); // Train an index with data provided void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x); @@ -81,6 +85,91 @@ bool isIndexIVFPQL2(faiss::Index * index); // IndexIDMap which has member that will point to underlying index that stores the data faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index); +void knn_jni::faiss_wrapper::CreateBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, + jstring indexPathJ, jobject parametersJ) { + if (idsJ == nullptr) { + throw std::runtime_error("IDs cannot be null"); + } + + if (vectorsAddressJ <= 0) { + throw std::runtime_error("VectorsAddress cannot be less than 0"); + } + + if(dimJ <= 0) { + throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); + } + + if (indexPathJ == nullptr) { + throw std::runtime_error("Index path cannot be null"); + } + + if (parametersJ == nullptr) { + throw std::runtime_error("Parameters cannot be null"); + } + + // parametersJ is a Java Map. ConvertJavaMapToCppMap converts it to a c++ map + // so that it is easier to access. + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + + // Get space type for this index + // Binary vector only support hamming distance and faiss does not receive any space type parameter +// jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE); +// std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); +// faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp); + + // Read vectors from memory address + auto *inputVectors = reinterpret_cast*>(vectorsAddressJ); + int dim = (int)dimJ; + // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value + int numVectors = (int) (inputVectors->size() * 8 / (uint64_t) dim); + if(numVectors == 0) { + throw std::runtime_error("Number of vectors cannot be 0"); + } + + int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); + if (numIds != numVectors) { + throw std::runtime_error("Number of IDs does not match number of vectors"); + } + + // Create faiss index + jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); + std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ)); + + std::unique_ptr indexWriter; + indexWriter.reset(faiss::index_binary_factory(dim, indexDescriptionCpp.c_str())); + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + omp_set_num_threads(threadCount); + } + + // Add extra parameters that cant be configured with the index factory + if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { + jobject subParametersJ = parametersCpp[knn_jni::PARAMETERS]; + auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, subParametersJ); + SetExtraParameters(jniUtil, env, subParametersCpp, indexWriter.get()); + jniUtil->DeleteLocalRef(env, subParametersJ); + } + jniUtil->DeleteLocalRef(env, parametersJ); + + // Check that the index does not need to be trained + if(!indexWriter->is_trained) { + throw std::runtime_error("Index is not trained"); + } + + auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); + faiss::IndexBinaryIDMap idMap = faiss::IndexBinaryIDMap(indexWriter.get()); + idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data()); + + // Write the index to disk + std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); + faiss::write_index_binary(&idMap, indexPathCpp.c_str()); + // Releasing the vectorsAddressJ memory as that is not required once we have created the index. + // This is not the ideal approach, please refer this gh issue for long term solution: + // https://github.com/opensearch-project/k-NN/issues/1600 + delete inputVectors; +} void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jobject parametersJ) { @@ -247,6 +336,19 @@ jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNI return (jlong) indexReader; } +jlong knn_jni::faiss_wrapper::LoadBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) { + if (indexPathJ == nullptr) { + throw std::runtime_error("Index path cannot be null"); + } + + std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); + // Skipping IO_FLAG_PQ_SKIP_SDC_TABLE because the index is read only and the sdc table is only used during ingestion + // Skipping IO_PRECOMPUTE_TABLE because it is only needed for IVFPQ-l2 and it leads to high memory consumption if + // done for each segment. Instead, we will set it later on with `setSharedIndexState` + faiss::IndexBinary* indexReader = faiss::read_index_binary(indexPathCpp.c_str(), faiss::IO_FLAG_READ_ONLY | faiss::IO_FLAG_PQ_SKIP_SDC_TABLE | faiss::IO_FLAG_SKIP_PRECOMPUTE_TABLE); + return (jlong) indexReader; +} + bool knn_jni::faiss_wrapper::IsSharedIndexStateRequired(jlong indexPointerJ) { auto * index = reinterpret_cast(indexPointerJ); return isIndexIVFPQL2(index); @@ -409,6 +511,114 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter return results; } +jobjectArray knn_jni::faiss_wrapper::QueryBinaryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, + jbyteArray queryVectorJ, jint kJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) { + + if (queryVectorJ == nullptr) { + throw std::runtime_error("Query Vector cannot be null"); + } + + auto *indexReader = reinterpret_cast(indexPointerJ); + + if (indexReader == nullptr) { + throw std::runtime_error("Invalid pointer to index"); + } + + // The ids vector will hold the top k ids from the search and the dis vector will hold the top k distances from + // the query point + std::vector dis(kJ); + std::vector ids(kJ); + int8_t* rawQueryvector = jniUtil->GetByteArrayElements(env, queryVectorJ, nullptr); + /* + Setting the omp_set_num_threads to 1 to make sure that no new OMP threads are getting created. + */ + omp_set_num_threads(1); + // create the filterSearch params if the filterIdsJ is not a null pointer + if(filterIdsJ != nullptr) { + jlong *filteredIdsArray = jniUtil->GetLongArrayElements(env, filterIdsJ, nullptr); + int filterIdsLength = jniUtil->GetJavaLongArrayLength(env, filterIdsJ); + std::unique_ptr idSelector; + if(filterIdsTypeJ == BITMAP) { + idSelector.reset(new faiss::IDSelectorJlongBitmap(filterIdsLength, filteredIdsArray)); + } else { + faiss::idx_t* batchIndices = reinterpret_cast(filteredIdsArray); + idSelector.reset(new faiss::IDSelectorBatch(filterIdsLength, batchIndices)); + } + faiss::SearchParameters *searchParameters; + faiss::SearchParametersHNSW hnswParams; + faiss::SearchParametersIVF ivfParams; + std::unique_ptr idGrouper; + std::vector idGrouperBitmap; + auto hnswReader = dynamic_cast(indexReader->index); + if(hnswReader) { + // Setting the ef_search value equal to what was provided during index creation. SearchParametersHNSW has a default + // value of ef_search = 16 which will then be used. + hnswParams.efSearch = hnswReader->hnsw.efSearch; + hnswParams.sel = idSelector.get(); + if (parentIdsJ != nullptr) { + idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap); + hnswParams.grp = idGrouper.get(); + } + searchParameters = &hnswParams; + } else { + auto ivfReader = dynamic_cast(indexReader->index); + if(ivfReader) { + ivfParams.sel = idSelector.get(); + searchParameters = &ivfParams; + } + } + try { + indexReader->search(1, reinterpret_cast(rawQueryvector), kJ, dis.data(), ids.data(), searchParameters); + } catch (...) { + jniUtil->ReleaseByteArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); + jniUtil->ReleaseLongArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); + throw; + } + jniUtil->ReleaseLongArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); + } else { + faiss::SearchParameters *searchParameters = nullptr; + faiss::SearchParametersHNSW hnswParams; + std::unique_ptr idGrouper; + std::vector idGrouperBitmap; + auto hnswReader = dynamic_cast(indexReader->index); + if(hnswReader!= nullptr && parentIdsJ != nullptr) { + // Setting the ef_search value equal to what was provided during index creation. SearchParametersHNSW has a default + // value of ef_search = 16 which will then be used. + hnswParams.efSearch = hnswReader->hnsw.efSearch; + idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap); + hnswParams.grp = idGrouper.get(); + searchParameters = &hnswParams; + } + try { + indexReader->search(1, reinterpret_cast(rawQueryvector), kJ, dis.data(), ids.data(), searchParameters); + } catch (...) { + jniUtil->ReleaseByteArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); + throw; + } + } + jniUtil->ReleaseByteArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); + + // If there are not k results, the results will be padded with -1. Find the first -1, and set result size to that + // index + int resultSize = kJ; + auto it = std::find(ids.begin(), ids.end(), -1); + if (it != ids.end()) { + resultSize = it - ids.begin(); + } + + jclass resultClass = jniUtil->FindClass(env,"org/opensearch/knn/index/query/KNNQueryResult"); + jmethodID allArgs = jniUtil->FindMethod(env, "org/opensearch/knn/index/query/KNNQueryResult", ""); + + jobjectArray results = jniUtil->NewObjectArray(env, resultSize, resultClass, nullptr); + + jobject result; + for(int i = 0; i < resultSize; ++i) { + result = jniUtil->NewObject(env, resultClass, allArgs, ids[i], dis[i]); + jniUtil->SetObjectArrayElement(env, results, i, result); + } + return results; +} + void knn_jni::faiss_wrapper::Free(jlong indexPointer) { auto *indexWrapper = reinterpret_cast(indexPointer); delete indexWrapper; @@ -535,6 +745,34 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, } } +void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, + const std::unordered_map& parametersCpp, faiss::IndexBinary * index) { + + std::unordered_map::const_iterator value; + if (auto * indexIvf = dynamic_cast(index)) { + if ((value = parametersCpp.find(knn_jni::NPROBES)) != parametersCpp.end()) { + indexIvf->nprobe = jniUtil->ConvertJavaObjectToCppInteger(env, value->second); + } + + if ((value = parametersCpp.find(knn_jni::COARSE_QUANTIZER)) != parametersCpp.end() + && indexIvf->quantizer != nullptr) { + auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, value->second); + SetExtraParameters(jniUtil, env, subParametersCpp, indexIvf->quantizer); + } + } + + if (auto * indexHnsw = dynamic_cast(index)) { + + if ((value = parametersCpp.find(knn_jni::EF_CONSTRUCTION)) != parametersCpp.end()) { + indexHnsw->hnsw.efConstruction = jniUtil->ConvertJavaObjectToCppInteger(env, value->second); + } + + if ((value = parametersCpp.find(knn_jni::EF_SEARCH)) != parametersCpp.end()) { + indexHnsw->hnsw.efSearch = jniUtil->ConvertJavaObjectToCppInteger(env, value->second); + } + } +} + void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) { if (auto * indexIvf = dynamic_cast(index)) { if (indexIvf->quantizer_trains_alone == 2) { diff --git a/jni/src/jni_util.cpp b/jni/src/jni_util.cpp index a1faa4894f..1c69ddd0f1 100644 --- a/jni/src/jni_util.cpp +++ b/jni/src/jni_util.cpp @@ -261,6 +261,39 @@ void knn_jni::JNIUtil::Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env env->DeleteLocalRef(array2dJ); } +void knn_jni::JNIUtil::Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, + int dim, std::vector *vect) { + + if (array2dJ == nullptr) { + throw std::runtime_error("Array cannot be null"); + } + + int numVectors = env->GetArrayLength(array2dJ); + this->HasExceptionInStack(env); + + for (int i = 0; i < numVectors; ++i) { + auto vectorArray = (jbyteArray)env->GetObjectArrayElement(array2dJ, i); + this->HasExceptionInStack(env, "Unable to get object array element"); + + if (dim != env->GetArrayLength(vectorArray)) { + throw std::runtime_error("Dimension of vectors is inconsistent"); + } + + uint8_t* vector = reinterpret_cast(env->GetByteArrayElements(vectorArray, nullptr)); + if (vector == nullptr) { + this->HasExceptionInStack(env); + throw std::runtime_error("Unable to get byte array elements"); + } + + for(int j = 0; j < dim; ++j) { + vect->push_back(vector[j]); + } + env->ReleaseByteArrayElements(vectorArray, reinterpret_cast(vector), JNI_ABORT); + } + this->HasExceptionInStack(env); + env->DeleteLocalRef(array2dJ); +} + std::vector knn_jni::JNIUtil::ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) { if (arrayJ == nullptr) { @@ -302,6 +335,23 @@ int knn_jni::JNIUtil::GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectAr return dim; } +int knn_jni::JNIUtil::GetInnerDimensionOf2dJavaByteArray(JNIEnv *env, jobjectArray array2dJ) { + + if (array2dJ == nullptr) { + throw std::runtime_error("Array cannot be null"); + } + + if (env->GetArrayLength(array2dJ) <= 0) { + return 0; + } + + auto vectorArray = (jbyteArray)env->GetObjectArrayElement(array2dJ, 0); + this->HasExceptionInStack(env); + int dim = env->GetArrayLength(vectorArray); + this->HasExceptionInStack(env); + return dim; +} + int knn_jni::JNIUtil::GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ) { if (arrayJ == nullptr) { diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 0aa51987dd..144ed6607b 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -39,6 +39,17 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) { jniUtil.Uninitialize(env); } +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndex(JNIEnv * env, jclass cls, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, + jstring indexPathJ, jobject parametersJ) +{ + try { + knn_jni::faiss_wrapper::CreateBinaryIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, parametersJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex(JNIEnv * env, jclass cls, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jobject parametersJ) @@ -75,6 +86,16 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex(JNIEn return NULL; } +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndex(JNIEnv * env, jclass cls, jstring indexPathJ) +{ + try { + return knn_jni::faiss_wrapper::LoadBinaryIndex(&jniUtil, env, indexPathJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return NULL; +} + JNIEXPORT jboolean JNICALL Java_org_opensearch_knn_jni_FaissService_isSharedIndexStateRequired (JNIEnv * env, jclass cls, jlong indexPointerJ) { @@ -132,6 +153,18 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd } +JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBinaryIndexWithFilter + (JNIEnv * env, jclass cls, jlong indexPointerJ, jbyteArray queryVectorJ, jint kJ, jlongArray filteredIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) { + + try { + return knn_jni::faiss_wrapper::QueryBinaryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, filteredIdsJ, filterIdsTypeJ, parentIdsJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return nullptr; + +} + JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free(JNIEnv * env, jclass cls, jlong indexPointerJ) { try { diff --git a/jni/src/org_opensearch_knn_jni_JNICommons.cpp b/jni/src/org_opensearch_knn_jni_JNICommons.cpp index ccdd118826..0bc2e46331 100644 --- a/jni/src/org_opensearch_knn_jni_JNICommons.cpp +++ b/jni/src/org_opensearch_knn_jni_JNICommons.cpp @@ -49,6 +49,18 @@ jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ) return (long)memoryAddressJ; } +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeByteVectorData(JNIEnv * env, jclass cls, +jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ) + +{ + try { + return knn_jni::commons::storeByteVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return (long)memoryAddressJ; +} + JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData(JNIEnv * env, jclass cls, jlong memoryAddressJ) { @@ -58,3 +70,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData(JNI jniUtil.CatchCppExceptionAndThrowJava(env); } } + + +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeByteVectorData(JNIEnv * env, jclass cls, + jlong memoryAddressJ) +{ + try { + return knn_jni::commons::freeByteVectorData(memoryAddressJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} diff --git a/jni/tests/test_util.cpp b/jni/tests/test_util.cpp index 92532b9e26..af3ddab3e4 100644 --- a/jni/tests/test_util.cpp +++ b/jni/tests/test_util.cpp @@ -51,6 +51,12 @@ test_util::MockJNIUtil::MockJNIUtil() { (*reinterpret_cast> *>(array2dJ))) for (auto item : v) data->push_back(item); }); + ON_CALL(*this, Convert2dJavaObjectArrayAndStoreToByteVector) + .WillByDefault([this](JNIEnv *env, jobjectArray array2dJ, int dim, std::vector* data) { + for (const auto &v : + (*reinterpret_cast> *>(array2dJ))) + for (auto item : v) data->push_back(item); + }); // arrayJ is re-interpreted as std::vector * @@ -150,6 +156,15 @@ test_util::MockJNIUtil::MockJNIUtil() { .size(); }); + // array2dJ is re-interpreted as a std::vector> * and then + // the size of the first element is returned + ON_CALL(*this, GetInnerDimensionOf2dJavaByteArray) + .WillByDefault([this](JNIEnv *env, jobjectArray array2dJ) { + return (*reinterpret_cast> *>( + array2dJ))[0] + .size(); + }); + // arrayJ is re-interpreted as a std::vector * and the size is returned ON_CALL(*this, GetJavaFloatArrayLength) .WillByDefault([this](JNIEnv *env, jfloatArray arrayJ) { diff --git a/jni/tests/test_util.h b/jni/tests/test_util.h index 8e73a8ab0c..23ad69eb1a 100644 --- a/jni/tests/test_util.h +++ b/jni/tests/test_util.h @@ -46,6 +46,8 @@ namespace test_util { (JNIEnv * env, jobjectArray array2dJ, int dim)); MOCK_METHOD(void, Convert2dJavaObjectArrayAndStoreToFloatVector, (JNIEnv * env, jobjectArray array2dJ, int dim, std::vector*vect)); + MOCK_METHOD(void, Convert2dJavaObjectArrayAndStoreToByteVector, + (JNIEnv * env, jobjectArray array2dJ, int dim, std::vector*vect)); MOCK_METHOD(std::vector, ConvertJavaIntArrayToCppIntVector, (JNIEnv * env, jintArray arrayJ)); MOCK_METHOD2(ConvertJavaMapToCppMap, @@ -64,6 +66,8 @@ namespace test_util { (JNIEnv * env, jfloatArray array, jboolean* isCopy)); MOCK_METHOD(int, GetInnerDimensionOf2dJavaFloatArray, (JNIEnv * env, jobjectArray array2dJ)); + MOCK_METHOD(int, GetInnerDimensionOf2dJavaByteArray, + (JNIEnv * env, jobjectArray array2dJ)); MOCK_METHOD(jint*, GetIntArrayElements, (JNIEnv * env, jintArray array, jboolean* isCopy)); MOCK_METHOD(jlong*, GetLongArrayElements, diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 53980bbb7e..16ebbbae6a 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -49,6 +49,20 @@ class FaissService { }); } + /** + * Create a binary index for the native library The memory occupied by the vectorsAddress will be freed up during the + * function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer + * created the memory address and that should only free up the memory. We are tracking the proper fix for this on this + * issue + * + * @param ids array of ids mapping to the data passed in + * @param vectorsAddress address of native memory where vectors are stored + * @param dim dimension of the vector to be indexed + * @param indexPath path to save index file to + * @param parameters parameters to build index + */ + public static native void createBinaryIndex(int[] ids, long vectorsAddress, int dim, String indexPath, Map parameters); + /** * Create an index for the native library The memory occupied by the vectorsAddress will be freed up during the * function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer @@ -90,6 +104,14 @@ public static native void createIndexFromTemplate( */ public static native long loadIndex(String indexPath); + /** + * Load a binary index into memory + * + * @param indexPath path to index file + * @return pointer to location in memory the index resides in + */ + public static native long loadBinaryIndex(String indexPath); + /** * Determine if index contains shared state. * @@ -150,6 +172,25 @@ public static native KNNQueryResult[] queryIndexWithFilter( int[] parentIds ); + /** + * Query a binary index with filter + * + * @param indexPointer pointer to index in memory + * @param queryVector vector to be used for query + * @param k neighbors to be returned + * @param filterIds list of doc ids to include in the query result + * @param parentIds list of parent doc ids when the knn field is a nested field + * @return KNNQueryResult array of k neighbors + */ + public static native KNNQueryResult[] queryBinaryIndexWithFilter( + long indexPointer, + byte[] queryVector, + int k, + long[] filterIds, + int filterIdsType, + int[] parentIds + ); + /** * Free native memory pointer */ diff --git a/src/main/java/org/opensearch/knn/jni/JNICommons.java b/src/main/java/org/opensearch/knn/jni/JNICommons.java index 90ad70c3d0..d0111b115e 100644 --- a/src/main/java/org/opensearch/knn/jni/JNICommons.java +++ b/src/main/java/org/opensearch/knn/jni/JNICommons.java @@ -47,6 +47,25 @@ public class JNICommons { */ public static native long storeVectorData(long memoryAddress, float[][] data, long initialCapacity); + /** + * This is utility function that can be used to store data in native memory. This function will allocate memory for + * the data(rows*columns) with initialCapacity and return the memory address where the data is stored. + * If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created. + * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location + * will throw Exception. + * + *

+ * The function is not threadsafe. If multiple threads are trying to insert on same memory location, then it can + * lead to data corruption. + *

+ * + * @param memoryAddress The address of the memory location where data will be stored. + * @param data 2D byte array containing data to be stored in native memory. + * @param initialCapacity The initial capacity of the memory location. + * @return memory address where the data is stored. + */ + public static native long storeByteVectorData(long memoryAddress, byte[][] data, long initialCapacity); + /** * Free up the memory allocated for the data stored in memory address. This function should be used with the memory * address returned by {@link JNICommons#storeVectorData(long, float[][], long)} diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index 20c4188197..ac0c3341da 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -12,6 +12,7 @@ package org.opensearch.knn.jni; import org.apache.commons.lang.ArrayUtils; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.query.KNNQueryResult; import org.opensearch.knn.index.util.KNNEngine; @@ -21,6 +22,7 @@ * Service to distribute requests to the proper engine jni service */ public class JNIService { + private static final String FAISS_BINARY_INDEX_PREFIX = "B"; /** * Create an index for the native library. The memory occupied by the vectorsAddress will be freed up during the @@ -50,7 +52,11 @@ public static void createIndex( } if (KNNEngine.FAISS == knnEngine) { - FaissService.createIndex(ids, vectorsAddress, dim, indexPath, parameters); + if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString().startsWith(FAISS_BINARY_INDEX_PREFIX)) { + FaissService.createBinaryIndex(ids, vectorsAddress, dim, indexPath, parameters); + } else { + FaissService.createIndex(ids, vectorsAddress, dim, indexPath, parameters); + } return; } @@ -101,7 +107,12 @@ public static long loadIndex(String indexPath, Map parameters, K } if (KNNEngine.FAISS == knnEngine) { - return FaissService.loadIndex(indexPath); + //TODO heemin pass index description in parameter + if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString().startsWith(FAISS_BINARY_INDEX_PREFIX)) { + return FaissService.loadBinaryIndex(indexPath); + } else { + return FaissService.loadIndex(indexPath); + } } throw new IllegalArgumentException(String.format("LoadIndex not supported for provided engine : %s", knnEngine.getName())); @@ -195,6 +206,36 @@ public static KNNQueryResult[] queryIndex( throw new IllegalArgumentException(String.format("QueryIndex not supported for provided engine : %s", knnEngine.getName())); } + /** + * Query a binary index + * + * @param indexPointer pointer to index in memory + * @param queryVector vector to be used for query + * @param k neighbors to be returned + * @param knnEngine engine to query index + * @param filteredIds array of ints on which should be used for search. + * @param filterIdsType how to filter ids: Batch or BitMap + * @return KNNQueryResult array of k neighbors + */ + public static KNNQueryResult[] queryBinaryIndex( + long indexPointer, + byte[] queryVector, + int k, + KNNEngine knnEngine, + long[] filteredIds, + int filterIdsType, + int[] parentIds + ) { + if (KNNEngine.FAISS == knnEngine) { + if (ArrayUtils.isEmpty(filteredIds) == false || parentIds != null) { + // Faiss library does not support search parameter for binary index which is required to pass filteredIds and parentIds. + throw new IllegalArgumentException("QueryBinaryIndex does not support filteredIds and parentIds"); + } + return FaissService.queryBinaryIndexWithFilter(indexPointer, queryVector, k, ArrayUtils.isEmpty(filteredIds) ? null : filteredIds, filterIdsType, parentIds); + } + throw new IllegalArgumentException(String.format("QueryBinaryIndex not supported for provided engine : %s", knnEngine.getName())); + } + /** * Free native memory pointer * diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index d6ae13e92d..864706255a 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -62,6 +62,7 @@ public class JNIServiceTests extends KNNTestCase { static TestUtils.TestData testData; static TestUtils.TestData testDataNested; private String faissMethod = "HNSW32,Flat"; + private String faissBinaryMethod = "BHNSW32"; @BeforeClass public static void setUpClass() throws IOException { @@ -647,6 +648,24 @@ public void testCreateIndex_faiss_valid() throws IOException { } } + @SneakyThrows + public void testCreateIndex_binary_faiss_valid() { + List methods = ImmutableList.of(faissBinaryMethod); + for (String method : methods) { + Path tmpFile1 = createTempFile(); + long memoryAddr = testData.loadBinaryDataToMemoryAddress(); + JNIService.createIndex( + testData.indexData.docs, + memoryAddr, + testData.indexData.getDimension(), + tmpFile1.toAbsolutePath().toString(), + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method), + KNNEngine.FAISS + ); + assertTrue(tmpFile1.toFile().length() > 0); + } + } + public void testLoadIndex_invalidEngine() { expectThrows(IllegalArgumentException.class, () -> JNIService.loadIndex("test", Collections.emptyMap(), KNNEngine.LUCENE)); } @@ -901,6 +920,37 @@ public void testQueryIndex_faiss_parentIds() throws IOException { } } + @SneakyThrows + public void testQueryBinaryIndex_faiss_valid() { + int k = 10; + List methods = ImmutableList.of(faissBinaryMethod); + for (String method : methods) { + Path tmpFile = createTempFile(); + long memoryAddr = testData.loadBinaryDataToMemoryAddress(); + JNIService.createIndex( + testData.indexData.docs, + memoryAddr, + testData.indexData.getDimension(), + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method), + KNNEngine.FAISS + ); + assertTrue(tmpFile.toFile().length() > 0); + + long pointer = JNIService.loadIndex( + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method), + KNNEngine.FAISS + ); + assertNotEquals(0, pointer); + + for (byte[] query : testData.binaryQueries) { + KNNQueryResult[] results = JNIService.queryBinaryIndex(pointer, query, k, KNNEngine.FAISS, null, 0, null); + assertEquals(k, results.length); + } + } + } + private Set toParentIdSet(KNNQueryResult[] results, Map idToParentIdMap) { return Arrays.stream(results).map(result -> idToParentIdMap.get(result.getId())).collect(Collectors.toSet()); } diff --git a/src/testFixtures/java/org/opensearch/knn/TestUtils.java b/src/testFixtures/java/org/opensearch/knn/TestUtils.java index a17b537d0c..4a8b2bf0a2 100644 --- a/src/testFixtures/java/org/opensearch/knn/TestUtils.java +++ b/src/testFixtures/java/org/opensearch/knn/TestUtils.java @@ -21,6 +21,8 @@ import org.opensearch.knn.index.codec.util.SerializationMode; import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.plugin.script.KNNScoringUtil; + +import java.util.Collections; import java.util.Comparator; import java.util.Random; import java.util.Set; @@ -252,11 +254,14 @@ public static PriorityQueue insertWithOverflow(PriorityQueue flattenedVectors = new ArrayList<>(indexData.vectors.length * indexData.vectors[0].length); + for (int i = 0; i < indexData.vectors.length; i++) { + for (int j = 0; j < indexData.vectors[i].length; j++) { + flattenedVectors.add(indexData.vectors[i][j]); + } + } + Collections.sort(flattenedVectors); + Float median = flattenedVectors.get(flattenedVectors.size() / 2); + + // Quantize + Packing for index data + indexBinaryData = new byte[indexData.vectors.length][(indexData.vectors[0].length + 7) / 8]; + for (int i = 0; i < indexData.vectors.length; i++) { + for (int j = 0; j < indexData.vectors[i].length; j++) { + int byteIndex = j / 8; + int bitIndex = 7 - (j % 8); + indexBinaryData[i][byteIndex] |= (indexData.vectors[i][j] >= median ? 1 : 0) << bitIndex; + } + } + + // Quantize + Packing for query data + binaryQueries = new byte[queries.length][(queries[0].length + 7) / 8]; + for (int i = 0; i < queries.length; i++) { + for (int j = 0; j < queries[i].length; j++) { + int byteIndex = j / 8; + int bitIndex = 7 - (j % 8); + binaryQueries[i][byteIndex] |= (queries[i][j] >= median ? 1 : 0) << bitIndex; + } + } + } + public long loadDataToMemoryAddress() { return JNICommons.storeVectorData(0, indexData.vectors, (long) indexData.vectors.length * indexData.vectors[0].length); } + public long loadBinaryDataToMemoryAddress() { + return JNICommons.storeByteVectorData(0, indexBinaryData, (long) indexBinaryData.length * indexBinaryData[0].length); + } + @AllArgsConstructor public static class Pair { public int[] docs;