Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor train index and create index from template APIs in JNI layer #1918

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
refactor jni create template index
Signed-off-by: Junqiu Lei <junqiu@amazon.com>
  • Loading branch information
junqiu-lei committed Aug 2, 2024
commit b661bac682da2b4f94878374094d843b049edd49
58 changes: 55 additions & 3 deletions jni/include/faiss_index_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
#include "jni_util.h"
#include "faiss_methods.h"
#include <memory>
#include <vector>
#include <unordered_map>
#include <string>

namespace knn_jni {
namespace faiss_wrapper {


/**
* A class to provide operations on index
* This class should evolve to have only cpp object but not jni object
Expand Down Expand Up @@ -61,20 +63,46 @@ class IndexService {
std::vector<int64_t> ids,
std::string indexPath,
std::unordered_map<std::string, jobject> parameters);

/**
* Create index from template
*
* @param jniUtil jni util
* @param env jni environment
* @param dim dimension of vectors
* @param numIds number of vectors
* @param vectorsAddress memory address which is holding vector data
* @param ids a list of document ids for corresponding vectors
* @param indexPath path to write index
* @param parameters parameters to be applied to faiss index
* @param templateIndexData vector containing the template index data
*/
virtual void createIndexFromTemplate(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
int dim,
int numIds,
int64_t vectorsAddress,
std::vector<int64_t> ids,
std::string indexPath,
std::unordered_map<std::string, jobject> parameters,
std::vector<uint8_t> templateIndexData);

virtual ~IndexService() = default;
protected:
std::unique_ptr<FaissMethods> faissMethods;
};

/**
* A class to provide operations on index
* A class to provide operations on binary index
* This class should evolve to have only cpp object but not jni object
*/
class BinaryIndexService : public IndexService {
public:
//TODO Remove dependency on JNIUtilInterface and JNIEnv
//TODO Reduce the number of parameters
BinaryIndexService(std::unique_ptr<FaissMethods> faissMethods);

/**
* Create binary index
*
Expand Down Expand Up @@ -103,11 +131,35 @@ class BinaryIndexService : public IndexService {
std::string indexPath,
std::unordered_map<std::string, jobject> parameters
) override;

/**
* Create binary index from template
*
* @param jniUtil jni util
* @param env jni environment
* @param dim dimension of vectors
* @param numIds number of vectors
* @param vectorsAddress memory address which is holding vector data
* @param ids a list of document ids for corresponding vectors
* @param indexPath path to write index
* @param parameters parameters to be applied to faiss index
* @param templateIndexData vector containing the template index data
*/
virtual void createIndexFromTemplate(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
int dim,
int numIds,
int64_t vectorsAddress,
std::vector<int64_t> ids,
std::string indexPath,
std::unordered_map<std::string, jobject> parameters,
std::vector<uint8_t> templateIndexData);

virtual ~BinaryIndexService() = default;
};

}
}


#endif //OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H
2 changes: 2 additions & 0 deletions jni/include/faiss_methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class FaissMethods {
virtual faiss::IndexIDMapTemplate<faiss::IndexBinary>* indexBinaryIdMap(faiss::IndexBinary* index);
virtual void writeIndex(const faiss::Index* idx, const char* fname);
virtual void writeIndexBinary(const faiss::IndexBinary* idx, const char* fname);
virtual faiss::Index* readIndex(faiss::IOReader* f, int io_flags);
virtual faiss::IndexBinary* readIndexBinary(faiss::IOReader* f, int io_flags);
virtual ~FaissMethods() = default;
};

Expand Down
8 changes: 1 addition & 7 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,7 @@ namespace knn_jni {
// based off of the template index passed in. The index is serialized to indexPathJ.
void CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ,
jobject parametersJ);

// Create an index with ids and vectors. Instead of creating a new index, this function creates the index
// based off of the template index passed in. The index is serialized to indexPathJ.
void CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ,
jobject parametersJ);
jobject parametersJ, IndexService* indexService);

// Load an index from indexPathJ into memory.
//
Expand Down
65 changes: 64 additions & 1 deletion jni/src/faiss_index_service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <vector>
#include <memory>
#include <type_traits>
#include <faiss/impl/io.h>

namespace knn_jni {
namespace faiss_wrapper {
Expand Down Expand Up @@ -106,6 +107,37 @@ void IndexService::createIndex(
faissMethods->writeIndex(idMap.get(), indexPath.c_str());
}

void IndexService::createIndexFromTemplate(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
int dim,
int numIds,
int64_t vectorsAddress,
std::vector<int64_t> ids,
std::string indexPath,
std::unordered_map<std::string, jobject> parameters,
std::vector<uint8_t> templateIndexData
) {
faiss::VectorIOReader vectorIoReader;
vectorIoReader.data = templateIndexData;

std::unique_ptr<faiss::Index> indexWriter(faissMethods->readIndex(&vectorIoReader, 0));

auto *inputVectors = reinterpret_cast<std::vector<float>*>(vectorsAddress);
int numVectors = (int) (inputVectors->size() / (uint64_t) dim);
if (numIds != numVectors) {
throw std::runtime_error("Number of vectors or IDs does not match expected values");
}

// Add extra parameters that cant be configured with the index factory
SetExtraParameters<faiss::Index, faiss::IndexIVF, faiss::IndexHNSW>(jniUtil, env, parameters, indexWriter.get());

std::unique_ptr<faiss::IndexIDMap> idMap(faissMethods->indexIdMap(indexWriter.get()));
idMap->add_with_ids(numVectors, inputVectors->data(), ids.data());

faissMethods->writeIndex(idMap.get(), indexPath.c_str());
}

BinaryIndexService::BinaryIndexService(std::unique_ptr<FaissMethods> faissMethods) : IndexService(std::move(faissMethods)) {}

void BinaryIndexService::createIndex(
Expand Down Expand Up @@ -160,5 +192,36 @@ void BinaryIndexService::createIndex(
faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str());
}

void BinaryIndexService::createIndexFromTemplate(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
int dim,
int numIds,
int64_t vectorsAddress,
std::vector<int64_t> ids,
std::string indexPath,
std::unordered_map<std::string, jobject> parameters,
std::vector<uint8_t> templateIndexData
) {
faiss::VectorIOReader vectorIoReader;
vectorIoReader.data = templateIndexData;

std::unique_ptr<faiss::IndexBinary> indexWriter(faissMethods->readIndexBinary(&vectorIoReader, 0));

auto *inputVectors = reinterpret_cast<std::vector<uint8_t>*>(vectorsAddress);
int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8));
if (numIds != numVectors) {
throw std::runtime_error("Number of vectors or IDs does not match expected values");
}

// Add extra parameters that cant be configured with the index factory
SetExtraParameters<faiss::IndexBinary, faiss::IndexBinaryIVF, faiss::IndexBinaryHNSW>(jniUtil, env, parameters, indexWriter.get());

std::unique_ptr<faiss::IndexBinaryIDMap> idMap(faissMethods->indexBinaryIdMap(indexWriter.get()));
idMap->add_with_ids(numVectors, inputVectors->data(), ids.data());

faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str());
}

} // namespace faiss_wrapper
} // namesapce knn_jni
} // namespace knn_jni
7 changes: 6 additions & 1 deletion jni/src/faiss_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ void FaissMethods::writeIndex(const faiss::Index* idx, const char* fname) {
void FaissMethods::writeIndexBinary(const faiss::IndexBinary* idx, const char* fname) {
faiss::write_index_binary(idx, fname);
}

faiss::Index* FaissMethods::readIndex(faiss::IOReader* f, int io_flags) {
return faiss::read_index(f, io_flags);
}
faiss::IndexBinary* FaissMethods::readIndexBinary(faiss::IOReader* f, int io_flags) {
return faiss::read_index_binary(f, io_flags);
}
} // namespace faiss_wrapper
} // namesapce knn_jni
108 changes: 11 additions & 97 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#include <vector>

// Defines type of IDSelector
enum FilterIdsSelectorType{
enum FilterIdsSelectorType {
BITMAP = 0, BATCH = 1,
};
namespace faiss {
Expand Down Expand Up @@ -76,7 +76,7 @@ void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const
// Converts the int FilterIds to Faiss ids type array.
void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds);

// Concerts the FilterIds to BitMap
// Converts the FilterIds to BitMap
void buildFilterIdsBitMap(const int* filterIds, int filterIdsLength, uint8_t* bitsetVector);

std::unique_ptr<faiss::IDGrouperBitmap> buildIDGrouperBitmap(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ, std::vector<uint64_t>* bitmap);
Expand Down Expand Up @@ -161,7 +161,7 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN

void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ,
jbyteArray templateIndexJ, jobject parametersJ) {
jbyteArray templateIndexJ, jobject parametersJ, IndexService* indexService) {
if (idsJ == nullptr) {
throw std::runtime_error("IDs cannot be null");
}
Expand Down Expand Up @@ -192,108 +192,22 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface *

// Read data set
// Read vectors from memory address
auto *inputVectors = reinterpret_cast<std::vector<float>*>(vectorsAddressJ);
int dim = (int)dimJ;
int numVectors = (int) (inputVectors->size() / (uint64_t) dim);
int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
if (numIds != numVectors) {
throw std::runtime_error("Number of IDs does not match number of vectors");
}

// Get vector of bytes from jbytearray
int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ);
jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr);

faiss::VectorIOReader vectorIoReader;
for (int i = 0; i < indexBytesCount; i++) {
vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]);
}
std::vector<uint8_t> templateIndexData(indexBytesJ, indexBytesJ + indexBytesCount);
jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT);

// Create faiss index
std::unique_ptr<faiss::Index> indexWriter;
indexWriter.reset(faiss::read_index(&vectorIoReader, 0));

auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get());
idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data());
// Releasing the vectorsAddressJ memory as that is not required once we have created the index.
// This is not the ideal approach, please refer this gh issue for long term solution:
// https://github.com/opensearch-project/k-NN/issues/1600
delete inputVectors;
// Write the index to disk
std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
faiss::write_index(&idMap, indexPathCpp.c_str());
}

void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ,
jbyteArray templateIndexJ, jobject parametersJ) {
if (idsJ == nullptr) {
throw std::runtime_error("IDs cannot be null");
}

if (vectorsAddressJ <= 0) {
throw std::runtime_error("VectorsAddress cannot be less than 0");
}

if(dimJ <= 0) {
throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0");
}

if (indexPathJ == nullptr) {
throw std::runtime_error("Index path cannot be null");
}

if (templateIndexJ == nullptr) {
throw std::runtime_error("Template index cannot be null");
}

// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);
if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
omp_set_num_threads(threadCount);
}
jniUtil->DeleteLocalRef(env, parametersJ);

// Read data set
// Read vectors from memory address
auto *inputVectors = reinterpret_cast<std::vector<uint8_t>*>(vectorsAddressJ);
int dim = (int)dimJ;
if (dim % 8 != 0) {
throw std::runtime_error("Dimensions should be multiply of 8");
}
int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8));
int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
if (numIds != numVectors) {
throw std::runtime_error("Number of IDs does not match number of vectors");
}

// Get vector of bytes from jbytearray
int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ);
jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr);

faiss::VectorIOReader vectorIoReader;
for (int i = 0; i < indexBytesCount; i++) {
vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]);
}
jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT);
// Convert ids
auto ids = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
int64_t vectorsAddress = (int64_t)vectorsAddressJ;
std::string indexPathCpp = jniUtil->ConvertJavaStringToCppString(env, indexPathJ);

// Create faiss index
std::unique_ptr<faiss::IndexBinary> indexWriter;
indexWriter.reset(faiss::read_index_binary(&vectorIoReader, 0));

auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
faiss::IndexBinaryIDMap idMap = faiss::IndexBinaryIDMap(indexWriter.get());
idMap.add_with_ids(numVectors, reinterpret_cast<const uint8_t*>(inputVectors->data()), idVector.data());
// Releasing the vectorsAddressJ memory as that is not required once we have created the index.
// This is not the ideal approach, please refer this gh issue for long term solution:
// https://github.com/opensearch-project/k-NN/issues/1600
delete inputVectors;
// Write the index to disk
std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
faiss::write_index_binary(&idMap, indexPathCpp.c_str());
// Create index using IndexService
indexService->createIndexFromTemplate(jniUtil, env, dim, numIds, vectorsAddress, ids, indexPathCpp, parametersCpp, templateIndexData);
}

jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) {
Expand Down Expand Up @@ -674,7 +588,7 @@ jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUti
omp_set_num_threads(threadCount);
}

// Add extra parameters that cant be configured with the index factory
// Add extra parameters that can't be configured with the index factory
if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) {
jobject subParametersJ = parametersCpp[knn_jni::PARAMETERS];
auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, subParametersJ);
Expand Down
8 changes: 6 additions & 2 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT
jobject parametersJ)
{
try {
knn_jni::faiss_wrapper::CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ);
std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods());
knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods));
CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ, &indexService);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
Expand All @@ -99,7 +101,9 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryInde
jobject parametersJ)
{
try {
knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ);
std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods());
knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods));
CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ, &binaryIndexService);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
Expand Down
Loading