Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor train index and create index from template APIs in JNI layer #1918

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Clean up parsing for query [#1824](https://github.com/opensearch-project/k-NN/pull/1824)
* Refactor engine package structure [#1913](https://github.com/opensearch-project/k-NN/pull/1913)
* Refactor method structure and definitions [#1920](https://github.com/opensearch-project/k-NN/pull/1920)
* Refactor train index and create index from template APIs in JNI layer [#1918](https://github.com/opensearch-project/k-NN/pull/1918)
112 changes: 106 additions & 6 deletions jni/include/faiss_index_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
#include "jni_util.h"
#include "faiss_methods.h"
#include <memory>
#include <vector>
#include <unordered_map>
#include <string>

namespace knn_jni {
namespace faiss_wrapper {


/**
* A class to provide operations on index
* This class should evolve to have only cpp object but not jni object
Expand Down Expand Up @@ -61,20 +63,69 @@ class IndexService {
std::vector<int64_t> ids,
std::string indexPath,
std::unordered_map<std::string, jobject> parameters);

/**
* Create index from template
*
* @param jniUtil jni util
* @param env jni environment
* @param dim dimension of vectors
* @param numIds number of vectors
* @param vectorsAddress memory address which is holding vector data
* @param ids a list of document ids for corresponding vectors
* @param indexPath path to write index
* @param parameters parameters to be applied to faiss index
* @param templateIndexData vector containing the template index data
*/
virtual void createIndexFromTemplate(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
int dim,
int numIds,
int64_t vectorsAddress,
std::vector<int64_t> ids,
std::string indexPath,
std::unordered_map<std::string, jobject> parameters,
std::vector<uint8_t> templateIndexData);

/**
* Train index
*
* @param index faiss index
* @param n number of vectors
* @param x memory address which is holding vector data
*/
virtual void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x);

/**
* Train index
* @param jniUtil jni util
* @param env jni environment
* @param metric space type for distance calculation
* @param indexDescription index description to be used by faiss index factory
* @param dimension dimension of vectors
* @param numVectors number of vectors
* @param trainingVectors memory address which is holding vector data
* @param parameters parameters to be applied to faiss index
* @return vector containing the trained index data
*/
virtual std::vector<uint8_t> trainIndex(JNIUtilInterface* jniUtil, JNIEnv* env, faiss::MetricType metric, std::string& indexDescription, int dimension, int numVectors, float* trainingVectors, std::unordered_map<std::string, jobject>& parameters);

virtual ~IndexService() = default;
protected:
std::unique_ptr<FaissMethods> faissMethods;
};

/**
* A class to provide operations on index
* A class to provide operations on binary index
* This class should evolve to have only cpp object but not jni object
*/
class BinaryIndexService : public IndexService {
public:
//TODO Remove dependency on JNIUtilInterface and JNIEnv
//TODO Reduce the number of parameters
BinaryIndexService(std::unique_ptr<FaissMethods> faissMethods);
explicit BinaryIndexService(std::unique_ptr<FaissMethods> faissMethods);

/**
* Create binary index
*
Expand All @@ -90,7 +141,7 @@ class BinaryIndexService : public IndexService {
* @param indexPath path to write index
* @param parameters parameters to be applied to faiss index
*/
virtual void createIndex(
void createIndex(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
faiss::MetricType metric,
Expand All @@ -103,11 +154,60 @@ class BinaryIndexService : public IndexService {
std::string indexPath,
std::unordered_map<std::string, jobject> parameters
) override;
virtual ~BinaryIndexService() = default;

/**
* Create binary index from template
*
* @param jniUtil jni util
* @param env jni environment
* @param dim dimension of vectors
* @param numIds number of vectors
* @param vectorsAddress memory address which is holding vector data
* @param ids a list of document ids for corresponding vectors
* @param indexPath path to write index
* @param parameters parameters to be applied to faiss index
* @param templateIndexData vector containing the template index data
*/
void createIndexFromTemplate(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
int dim,
int numIds,
int64_t vectorsAddress,
std::vector<int64_t> ids,
std::string indexPath,
std::unordered_map<std::string, jobject> parameters,
std::vector<uint8_t> templateIndexData) override;

/**
* Train binary index
*
* @param index faiss index
* @param n number of vectors
* @param x memory address which is holding vector data
*/
void InternalTrainIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x);

/**
* Train binary index
*
* @param jniUtil jni util
* @param env jni environment
* @param metric space type for distance calculation
* @param indexDescription index description to be used by faiss index factory
* @param dimension dimension of vectors
* @param numVectors number of vectors
* @param trainingVectors memory address which is holding vector data
* @param parameters parameters to be applied to faiss index
* @return vector containing the trained index data
*/
std::vector<uint8_t> trainIndex(JNIUtilInterface* jniUtil, JNIEnv* env, faiss::MetricType metric, std::string& indexDescription, int dimension, int numVectors, float* trainingVectors, std::unordered_map<std::string, jobject>& parameters) override;


~BinaryIndexService() override = default;
};

}
}


#endif //OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H
2 changes: 2 additions & 0 deletions jni/include/faiss_methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class FaissMethods {
virtual faiss::IndexIDMapTemplate<faiss::IndexBinary>* indexBinaryIdMap(faiss::IndexBinary* index);
virtual void writeIndex(const faiss::Index* idx, const char* fname);
virtual void writeIndexBinary(const faiss::IndexBinary* idx, const char* fname);
virtual faiss::Index* readIndex(faiss::IOReader* f, int io_flags);
virtual faiss::IndexBinary* readIndexBinary(faiss::IOReader* f, int io_flags);
virtual ~FaissMethods() = default;
};

Expand Down
17 changes: 2 additions & 15 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,7 @@ namespace knn_jni {
// based off of the template index passed in. The index is serialized to indexPathJ.
void CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ,
jobject parametersJ);

// Create an index with ids and vectors. Instead of creating a new index, this function creates the index
// based off of the template index passed in. The index is serialized to indexPathJ.
void CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ,
jobject parametersJ);
jobject parametersJ, IndexService* indexService);

// Load an index from indexPathJ into memory.
//
Expand Down Expand Up @@ -100,14 +94,7 @@ namespace knn_jni {
//
// Return the serialized representation
jbyteArray TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension,
jlong trainVectorsPointerJ);

// Create an empty binary index defined by the values in the Java map, parametersJ. Train the index with
// the vector of floats located at trainVectorsPointerJ.
//
// Return the serialized representation
jbyteArray TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension,
jlong trainVectorsPointerJ);
jlong trainVectorsPointerJ, IndexService* indexService);

/*
* Perform a range search with filter against the index located in memory at indexPointerJ.
Expand Down
126 changes: 125 additions & 1 deletion jni/src/faiss_index_service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <vector>
#include <memory>
#include <type_traits>
#include <faiss/impl/io.h>

namespace knn_jni {
namespace faiss_wrapper {
Expand Down Expand Up @@ -106,6 +107,68 @@ void IndexService::createIndex(
faissMethods->writeIndex(idMap.get(), indexPath.c_str());
}

/**
 * Create an index from a serialized template index, add the given vectors
 * under the given ids, and write the result to indexPath.
 *
 * @param jniUtil jni util, forwarded to SetExtraParameters
 * @param env jni environment, forwarded to SetExtraParameters
 * @param dim dimension of vectors; must be positive
 * @param numIds number of ids; must equal the number of vectors derived from the buffer size
 * @param vectorsAddress address of a std::vector<float> holding the vector data
 * @param ids document ids for the corresponding vectors
 * @param indexPath path to write index
 * @param parameters parameters to be applied to faiss index
 * @param templateIndexData serialized template index
 * @throws std::runtime_error if dim is not positive, vectorsAddress is null,
 *         or numIds does not match the number of vectors
 */
void IndexService::createIndexFromTemplate(
        knn_jni::JNIUtilInterface * jniUtil,
        JNIEnv * env,
        int dim,
        int numIds,
        int64_t vectorsAddress,
        std::vector<int64_t> ids,
        std::string indexPath,
        std::unordered_map<std::string, jobject> parameters,
        std::vector<uint8_t> templateIndexData
    ) {
    // Validate the cheap inputs first: dim is used as a divisor below and the
    // address is dereferenced, so a bad value would otherwise fault natively.
    if (dim <= 0) {
        throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0");
    }

    auto *inputVectors = reinterpret_cast<std::vector<float>*>(vectorsAddress);
    if (inputVectors == nullptr) {
        throw std::runtime_error("VectorsAddress cannot be null");
    }

    const int numVectors = static_cast<int>(inputVectors->size() / static_cast<uint64_t>(dim));
    if (numIds != numVectors) {
        throw std::runtime_error("Number of vectors or IDs does not match expected values");
    }

    // templateIndexData is a by-value parameter, so hand its buffer to the
    // reader instead of copying it.
    faiss::VectorIOReader vectorIoReader;
    vectorIoReader.data = std::move(templateIndexData);

    // Declared before idMap so that it is destroyed after idMap, which wraps
    // the raw pointer.
    std::unique_ptr<faiss::Index> indexWriter(faissMethods->readIndex(&vectorIoReader, 0));

    // Add extra parameters that can't be configured with the index factory
    SetExtraParameters<faiss::Index, faiss::IndexIVF, faiss::IndexHNSW>(jniUtil, env, parameters, indexWriter.get());

    std::unique_ptr<faiss::IndexIDMap> idMap(faissMethods->indexIdMap(indexWriter.get()));
    idMap->add_with_ids(numVectors, inputVectors->data(), ids.data());

    faissMethods->writeIndex(idMap.get(), indexPath.c_str());
}

/**
 * Train the given index on n vectors read from x, unless it is already trained.
 *
 * For an IVF index, the quantizer is trained first through a recursive call
 * when quantizer_trains_alone is 2, and make_direct_map() is called on the
 * index before the training check.
 *
 * @param index faiss index
 * @param n number of vectors
 * @param x memory address which is holding vector data
 */
void IndexService::InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) {
    auto * ivf = dynamic_cast<faiss::IndexIVF*>(index);
    if (ivf != nullptr) {
        const bool trainQuantizerFirst = (ivf->quantizer_trains_alone == 2);
        if (trainQuantizerFirst) {
            InternalTrainIndex(ivf->quantizer, n, x);
        }
        ivf->make_direct_map();
    }

    // Nothing left to do for an index that reports itself as trained.
    if (index->is_trained) {
        return;
    }
    index->train(n, x);
}

/**
 * Create an index through the index factory, train it if it is not already
 * trained, and return its serialized bytes.
 *
 * @param jniUtil jni util (not used by this implementation)
 * @param env jni environment (not used by this implementation)
 * @param metric space type for distance calculation
 * @param indexDescription index description to be used by faiss index factory
 * @param dimension dimension of vectors
 * @param numVectors number of vectors
 * @param trainingVectors memory address which is holding vector data
 * @param parameters parameters to be applied to faiss index (not read by this implementation)
 * @return vector containing the trained index data
 */
std::vector<uint8_t> IndexService::trainIndex(JNIUtilInterface* jniUtil, JNIEnv* env, faiss::MetricType metric, std::string& indexDescription, int dimension, int numVectors, float* trainingVectors, std::unordered_map<std::string, jobject>& parameters) {
    // Create faiss index
    std::unique_ptr<faiss::Index> index(faissMethods->indexFactory(dimension, indexDescription.c_str(), metric));

    // Train index if needed
    if (!index->is_trained) {
        InternalTrainIndex(index.get(), numVectors, trainingVectors);
    }

    // Write index to an in-memory buffer
    faiss::VectorIOWriter vectorIoWriter;
    faiss::write_index(index.get(), &vectorIoWriter);

    // The writer is a local that is about to go away: move its buffer out
    // rather than copying the whole serialized index.
    return std::move(vectorIoWriter.data);
}



// Hands the injected FaissMethods over to the IndexService base class.
BinaryIndexService::BinaryIndexService(std::unique_ptr<FaissMethods> faissMethods)
    : IndexService(std::move(faissMethods)) {
}

void BinaryIndexService::createIndex(
Expand Down Expand Up @@ -160,5 +223,66 @@ void BinaryIndexService::createIndex(
faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str());
}

/**
 * Create a binary index from a serialized template index, add the given
 * vectors under the given ids, and write the result to indexPath.
 *
 * @param jniUtil jni util, forwarded to SetExtraParameters
 * @param env jni environment, forwarded to SetExtraParameters
 * @param dim dimension of vectors in bits; must be a positive multiple of 8
 * @param numIds number of ids; must equal the number of vectors derived from the buffer size
 * @param vectorsAddress address of a std::vector<uint8_t> holding the vector data
 * @param ids document ids for the corresponding vectors
 * @param indexPath path to write index
 * @param parameters parameters to be applied to faiss index
 * @param templateIndexData serialized template index
 * @throws std::runtime_error if dim is not a positive multiple of 8,
 *         vectorsAddress is null, or numIds does not match the number of vectors
 */
void BinaryIndexService::createIndexFromTemplate(
        knn_jni::JNIUtilInterface * jniUtil,
        JNIEnv * env,
        int dim,
        int numIds,
        int64_t vectorsAddress,
        std::vector<int64_t> ids,
        std::string indexPath,
        std::unordered_map<std::string, jobject> parameters,
        std::vector<uint8_t> templateIndexData
    ) {
    // dim / 8 is used as the bytes-per-vector divisor below. A dim under 8
    // would make it zero (division by zero), and a dim that is not a multiple
    // of 8 would be silently truncated.
    if (dim < 8 || dim % 8 != 0) {
        throw std::runtime_error("Dimension of binary vectors must be a positive multiple of 8");
    }

    auto *inputVectors = reinterpret_cast<std::vector<uint8_t>*>(vectorsAddress);
    if (inputVectors == nullptr) {
        throw std::runtime_error("VectorsAddress cannot be null");
    }

    const auto bytesPerVector = static_cast<uint64_t>(dim / 8);
    const int numVectors = static_cast<int>(inputVectors->size() / bytesPerVector);
    if (numIds != numVectors) {
        throw std::runtime_error("Number of vectors or IDs does not match expected values");
    }

    // templateIndexData is a by-value parameter, so hand its buffer to the
    // reader instead of copying it.
    faiss::VectorIOReader vectorIoReader;
    vectorIoReader.data = std::move(templateIndexData);

    // Declared before idMap so that it is destroyed after idMap, which wraps
    // the raw pointer.
    std::unique_ptr<faiss::IndexBinary> indexWriter(faissMethods->readIndexBinary(&vectorIoReader, 0));

    // Add extra parameters that can't be configured with the index factory
    SetExtraParameters<faiss::IndexBinary, faiss::IndexBinaryIVF, faiss::IndexBinaryHNSW>(jniUtil, env, parameters, indexWriter.get());

    std::unique_ptr<faiss::IndexBinaryIDMap> idMap(faissMethods->indexBinaryIdMap(indexWriter.get()));
    idMap->add_with_ids(numVectors, inputVectors->data(), ids.data());

    faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str());
}

/**
 * Train the given binary index on n vectors, unless it is already trained.
 *
 * The training buffer arrives typed as const float* and is reinterpreted as
 * bytes before being passed to train().
 *
 * No IndexBinaryIVF special case is needed: it would perform the same
 * is_trained check and the same train() call as the code below.
 *
 * @param index faiss binary index
 * @param n number of vectors
 * @param x memory address which is holding vector data
 */
void BinaryIndexService::InternalTrainIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x) {
    if (!index->is_trained) {
        index->train(n, reinterpret_cast<const uint8_t*>(x));
    }
}

/**
 * Create a binary index through the binary index factory, train it if it is
 * not already trained, and return its serialized bytes.
 *
 * @param jniUtil jni util (not used by this implementation)
 * @param env jni environment (not used by this implementation)
 * @param metric space type for distance calculation (not passed to the binary index factory)
 * @param indexDescription index description to be used by faiss binary index factory
 * @param dimension dimension of vectors
 * @param numVectors number of vectors
 * @param trainingVectors memory address which is holding vector data
 * @param parameters parameters to be applied to faiss index (not read by this implementation)
 * @return vector containing the trained index data
 */
std::vector<uint8_t> BinaryIndexService::trainIndex(JNIUtilInterface* jniUtil, JNIEnv* env, faiss::MetricType metric, std::string& indexDescription, int dimension, int numVectors, float* trainingVectors, std::unordered_map<std::string, jobject>& parameters) {
    // Create the binary faiss index from its description.
    // NOTE(review): this calls faiss::index_binary_factory directly, whereas
    // IndexService::trainIndex goes through faissMethods — confirm whether
    // this path should be routed through faissMethods as well.
    std::unique_ptr<faiss::IndexBinary> indexWriter(faiss::index_binary_factory(dimension, indexDescription.c_str()));

    // Train the index if it is not already trained
    if (!indexWriter->is_trained) {
        InternalTrainIndex(indexWriter.get(), numVectors, trainingVectors);
    }

    // Serialize the trained index into an in-memory buffer
    faiss::VectorIOWriter vectorIoWriter;
    faiss::write_index_binary(indexWriter.get(), &vectorIoWriter);

    // The writer is a local that is about to go away: move its buffer out
    // rather than copying the whole serialized index.
    return std::move(vectorIoWriter.data);
}
} // namespace faiss_wrapper
} // namesapce knn_jni
} // namespace knn_jni
8 changes: 8 additions & 0 deletions jni/src/faiss_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,17 @@ faiss::IndexIDMapTemplate<faiss::IndexBinary>* FaissMethods::indexBinaryIdMap(fa
// Forwards to faiss::write_index, writing idx to the file named fname.
void FaissMethods::writeIndex(const faiss::Index* idx, const char* fname) {
    ::faiss::write_index(idx, fname);
}

// Forwards to faiss::write_index_binary, writing idx to the file named fname.
void FaissMethods::writeIndexBinary(const faiss::IndexBinary* idx, const char* fname) {
    ::faiss::write_index_binary(idx, fname);
}

// Forwards to faiss::read_index and returns the pointer it produces.
faiss::Index* FaissMethods::readIndex(faiss::IOReader* f, int io_flags) {
    faiss::Index* loadedIndex = ::faiss::read_index(f, io_flags);
    return loadedIndex;
}

// Forwards to faiss::read_index_binary and returns the pointer it produces.
faiss::IndexBinary* FaissMethods::readIndexBinary(faiss::IOReader* f, int io_flags) {
    faiss::IndexBinary* loadedIndex = ::faiss::read_index_binary(f, io_flags);
    return loadedIndex;
}
} // namespace faiss_wrapper
} // namespace knn_jni
Loading
Loading