Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds Nprobe as a method parameter in knn query #1758

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,6 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter
if (methodParamsJ != nullptr) {
methodParams = jniUtil->ConvertJavaMapToCppMap(env, methodParamsJ);
}

// The ids vector will hold the top k ids from the search and the dis vector will hold the top k distances from
// the query point
std::vector<float> dis(kJ);
Expand Down Expand Up @@ -357,7 +356,10 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter
} else {
auto ivfReader = dynamic_cast<const faiss::IndexIVF*>(indexReader->index);
auto ivfFlatReader = dynamic_cast<const faiss::IndexIVFFlat*>(indexReader->index);

if (ivfReader || ivfFlatReader) {
    // Read the index-level nprobe through whichever cast succeeded. The earlier
    // ternary was inverted: it dereferenced ivfReader exactly when it was null,
    // and otherwise read ivfFlatReader, which is null for IVF indices that are
    // not IVFFlat.
    const int indexNprobe = ivfReader != nullptr ? ivfReader->nprobe : ivfFlatReader->nprobe;
    // indexNprobe is handed over as the fallback used when methodParams carries
    // no NPROBES entry.
    ivfParams.nprobe = commons::getIntegerMethodParameter(env, jniUtil, methodParams, NPROBES, indexNprobe);
    ivfParams.sel = idSelector.get();
    searchParameters = &ivfParams;
}
Expand All @@ -373,17 +375,25 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter
} else {
faiss::SearchParameters *searchParameters = nullptr;
faiss::SearchParametersHNSW hnswParams;
faiss::SearchParametersIVF ivfParams;
std::unique_ptr<faiss::IDGrouperBitmap> idGrouper;
std::vector<uint64_t> idGrouperBitmap;
auto hnswReader = dynamic_cast<const faiss::IndexHNSW*>(indexReader->index);
if(hnswReader!= nullptr) {
if(hnswReader != nullptr) {
// Query param efsearch supersedes ef_search provided during index setting.
hnswParams.efSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, hnswReader->hnsw.efSearch);
if (parentIdsJ != nullptr) {
idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap);
hnswParams.grp = idGrouper.get();
}
searchParameters = &hnswParams;
} else {
auto ivfReader = dynamic_cast<const faiss::IndexIVF*>(indexReader->index);
if (ivfReader) {
int indexNprobe = ivfReader->nprobe;
ivfParams.nprobe = commons::getIntegerMethodParameter(env, jniUtil, methodParams, NPROBES, indexNprobe);
searchParameters = &ivfParams;
}
}
try {
indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data(), searchParameters);
Expand Down
174 changes: 144 additions & 30 deletions jni/tests/faiss_wrapper_unit_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include "faiss_wrapper.h"

#include <vector>
#include <faiss/IndexIVFFlat.h>
#include <faiss/IndexIVFPQ.h>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
Expand All @@ -30,6 +32,46 @@ struct MockIndex : faiss::IndexHNSW {
}
};

// Test double for an IVF index: a faiss::IndexIVFFlat that adds nothing beyond a
// default constructor, so a fixture can own one and set fields such as nprobe on it.
struct MockIVFIndex : faiss::IndexIVFFlat {
explicit MockIVFIndex() = default;
};

// Recording stand-in for faiss::IndexIDMap wrapping a MockIVFIndex.
// search() performs no lookup at all; it only remembers the arguments it was
// called with so that a test can inspect them afterwards.
struct MockIVFIdMap : faiss::IndexIDMap {
    // Arguments seen by the most recent search() call; zero/null until then.
    mutable idx_t nCalled{};
    mutable const float *xCalled{};
    mutable idx_t kCalled{};
    mutable float *distancesCalled{};
    mutable idx_t *labelsCalled{};
    // Null when the caller passed no parameters or parameters that are not
    // SearchParametersIVF.
    // NOTE(review): only the pointer is stored, not a copy; presumably the
    // pointee is owned by the caller of search() - confirm it outlives the
    // assertions that read it.
    mutable const faiss::SearchParametersIVF *paramsCalled{};

    explicit MockIVFIdMap(MockIVFIndex *index)
        : faiss::IndexIDMapTemplate<faiss::Index>(index) {}

    // Captures every argument verbatim; writes nothing to distances or labels.
    void search(idx_t n, const float *x, idx_t k, float *distances, idx_t *labels,
                const faiss::SearchParameters *params) const override {
        paramsCalled = dynamic_cast<const faiss::SearchParametersIVF *>(params);
        labelsCalled = labels;
        distancesCalled = distances;
        kCalled = k;
        xCalled = x;
        nCalled = n;
    }

    // Restores the captured state to what it is before any search() call.
    void resetMock() const {
        paramsCalled = nullptr;
        labelsCalled = nullptr;
        distancesCalled = nullptr;
        xCalled = nullptr;
        kCalled = 0;
        nCalled = 0;
    }
};

struct MockIdMap : faiss::IndexIDMap {
mutable idx_t nCalled{};
mutable const float *xCalled{};
Expand Down Expand Up @@ -83,13 +125,14 @@ struct MockIdMap : faiss::IndexIDMap {
}
};

struct QueryIndexHNSWTestInput {
std::string description;
struct QueryIndexInput {
string description;
int k;
int efSearch;
int filterIdType;
bool filterIdsPresent;
bool parentIdsPresent;
int efSearch;
int nprobe;
};

struct RangeSearchTestInput {
Expand All @@ -101,9 +144,9 @@ struct RangeSearchTestInput {
bool parentIdsPresent;
};

class FaissWrappeterParametrizedTestFixture : public testing::TestWithParam<QueryIndexHNSWTestInput> {
class FaissWrapperParametrizedTestFixture : public testing::TestWithParam<QueryIndexInput> {
shatejas marked this conversation as resolved.
Show resolved Hide resolved
public:
FaissWrappeterParametrizedTestFixture() : index_(3), id_map_(&index_) {
FaissWrapperParametrizedTestFixture() : index_(3), id_map_(&index_) {
index_.hnsw.efSearch = 100; // assigning 100 to make sure default of 16 is not used anywhere
}

Expand All @@ -123,16 +166,25 @@ class FaissWrapperParametrizedRangeSearchTestFixture : public testing::TestWithP
MockIdMap id_map_;
};

namespace query_index_test {
class FaissWrapperIVFQueryTestFixture : public testing::TestWithParam<QueryIndexInput> {
public:
FaissWrapperIVFQueryTestFixture() : ivf_id_map_(&ivf_index_) {
ivf_index_.nprobe = 100;
};

std::unordered_map<std::string, jobject> methodParams;
protected:
MockIVFIndex ivf_index_;
MockIVFIdMap ivf_id_map_;
};

TEST_P(FaissWrappeterParametrizedTestFixture, QueryIndexHNSWTests) {
// Given
namespace query_index_test {

TEST_P(FaissWrapperParameterizedTestFixture, QueryIndexHNSWTests) {
//Given
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;

QueryIndexHNSWTestInput const &input = GetParam();
QueryIndexInput const &input = GetParam();
float query[] = {1.2, 2.3, 3.4};

int efSearch = input.efSearch;
Expand Down Expand Up @@ -184,24 +236,23 @@ namespace query_index_test {

INSTANTIATE_TEST_CASE_P(
QueryIndexHNSWTests,
FaissWrappeterParametrizedTestFixture,
FaissWrapperParametrizedTestFixture,
::testing::Values(
QueryIndexHNSWTestInput{"algoParams present, parent absent", 10, 200, 0, false, false},
QueryIndexHNSWTestInput{"algoParams absent, parent absent", 10, -1, 0, false, false},
QueryIndexHNSWTestInput{"algoParams present, parent present", 10, 200, 0, false, true},
QueryIndexHNSWTestInput{"algoParams absent, parent present", 10, -1, 0, false, true}
QueryIndexInput {"algoParams present, parent absent", 10, 0, false, false, 200, -1 },
QueryIndexInput {"algoParams present, parent absent", 10, 0, false, false, -1, -1 },
QueryIndexInput {"algoParams present, parent present", 10, 0, false, true, 200, -1 },
QueryIndexInput {"algoParams absent, parent present", 10, 0, false, true, -1, -1 }
)
);
}

namespace query_index_with_filter_test {

TEST_P(FaissWrappeterParametrizedTestFixture, QueryIndexWithFilterHNSWTests) {
// Given
TEST_P(FaissWrapperParameterizedTestFixture, QueryIndexWithFilterHNSWTests) {
//Given
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;

QueryIndexHNSWTestInput const &input = GetParam();
QueryIndexInput const &input = GetParam();
float query[] = {1.2, 2.3, 3.4};

std::vector<int> *parentIdPtr = nullptr;
Expand Down Expand Up @@ -267,23 +318,23 @@ namespace query_index_with_filter_test {

INSTANTIATE_TEST_CASE_P(
QueryIndexWithFilterHNSWTests,
FaissWrappeterParametrizedTestFixture,
FaissWrapperParametrizedTestFixture,
::testing::Values(
QueryIndexHNSWTestInput{"algoParams present, parent absent, filter absent", 10, 200, 0, false, false},
QueryIndexHNSWTestInput{"algoParams present, parent absent, filter absent, filter type 1", 10, 200, 1, false, false},
QueryIndexHNSWTestInput{"algoParams absent, parent absent, filter present", 10, -1, 0, true, false},
QueryIndexHNSWTestInput{"algoParams absent, parent absent, filter present, filter type 1", 10, -1, 1, true, false},
QueryIndexHNSWTestInput{"algoParams present, parent present, filter absent", 10, 200, 0, false, true},
QueryIndexHNSWTestInput{"algoParams present, parent present, filter absent, filter type 1", 10, 150, 1, false, true},
QueryIndexHNSWTestInput{"algoParams absent, parent present, filter present", 10, -1, 0, true, true},
QueryIndexHNSWTestInput{"algoParams absent, parent present, filter present, filter type 1",10, -1, 1, true, true}
QueryIndexInput { "algoParams present, parent absent, filter absent", 10, 0, false, false, 200, -1 },
QueryIndexInput { "algoParams present, parent absent, filter absent, filter type 1", 10, 1, false, false, 200, -1},
QueryIndexInput { "algoParams absent, parent absent, filter present", 10, 0, true, false, -1, -1},
QueryIndexInput { "algoParams absent, parent absent, filter present, filter type 1", 10, 1, true, false, -1, -1},
QueryIndexInput { "algoParams present, parent present, filter absent", 10, 0, false, true, 200, -1 },
QueryIndexInput { "algoParams present, parent present, filter absent, filter type 1", 10, 1, false, true, 150, -1},
QueryIndexInput { "algoParams absent, parent present, filter present", 10, 0, true, true, -1, -1},
QueryIndexInput { "algoParams absent, parent present, filter present, filter type 1",10, 1, true, true, -1, -1 }
)
);
}

namespace range_search_test {

TEST_P(FaissWrapperParametrizedRangeSearchTestFixture, RangeSearchHNSWTests) {
TEST_P(FaissWrapperParameterizedRangeSearchTestFixture, RangeSearchHNSWTests) {
// Given
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;
Expand Down Expand Up @@ -323,6 +374,7 @@ namespace range_search_test {
std::vector<long> filter;
std::vector<long> *filterptr = nullptr;
if (input.filterIdsPresent) {
std::vector<long> filter;
filter.reserve(2);
filter.push_back(1);
filter.push_back(2);
Expand Down Expand Up @@ -356,7 +408,7 @@ namespace range_search_test {

INSTANTIATE_TEST_CASE_P(
RangeSearchHNSWTests,
FaissWrapperParametrizedRangeSearchTestFixture,
FaissWrapperParameterizedRangeSearchTestFixture,
::testing::Values(
RangeSearchTestInput{"algoParams present, parent absent, filter absent", 10.0f, 200, 0, false, false},
RangeSearchTestInput{"algoParams present, parent absent, filter absent, filter type 1", 10.0f, 200, 1, false, false},
Expand All @@ -370,3 +422,65 @@ namespace range_search_test {
);
}

namespace query_index_with_filter_test_ivf {

    // Checks the nprobe that QueryIndex_WithFilter hands to an IVF-backed index:
    // the value from methodParams when one is supplied, otherwise the value
    // configured on the index by the fixture (100).
    TEST_P(FaissWrapperIVFQueryTestFixture, QueryIndexIVFTest) {
        // Given
        JNIEnv *jniEnv = nullptr;
        NiceMock<test_util::MockJNIUtil> mockJNIUtil;

        QueryIndexInput const &input = GetParam();
        float query[] = {1.2, 2.3, 3.4};

        int nprobe = input.nprobe;
        int expectedNprobe = 100; // index-level value set by the fixture
        std::unordered_map<std::string, jobject> methodParams;
        if (nprobe != -1) {
            // -1 means "no query-time override"; anything else is passed along.
            expectedNprobe = input.nprobe;
            methodParams[knn_jni::NPROBES] = reinterpret_cast<jobject>(&nprobe);
        }

        // The vector lives at function scope so that filterptr is still valid
        // when the wrapper is called. The earlier version declared a second,
        // shadowing filterptr inside a duplicated if-block, so the outer
        // pointer was never set and no filter was ever passed.
        std::vector<long> filter;
        std::vector<long> *filterptr = nullptr;
        if (input.filterIdsPresent) {
            filter.reserve(2);
            filter.push_back(1);
            filter.push_back(2);
            filterptr = &filter;
        }

        // When
        knn_jni::faiss_wrapper::QueryIndex_WithFilter(
            &mockJNIUtil, jniEnv,
            reinterpret_cast<jlong>(&ivf_id_map_),
            reinterpret_cast<jfloatArray>(&query), input.k, reinterpret_cast<jobject>(&methodParams),
            reinterpret_cast<jlongArray>(filterptr),
            input.filterIdType,
            nullptr);

        // Then
        // Stop here if no IVF parameters were captured instead of dereferencing null.
        ASSERT_TRUE(ivf_id_map_.paramsCalled != nullptr);
        // NOTE(review): paramsCalled is a raw pointer captured by the mock;
        // presumably it refers to an object owned by QueryIndex_WithFilter -
        // confirm it is still alive at this point.
        int actualNprobe = ivf_id_map_.paramsCalled->nprobe;
        EXPECT_EQ(input.k, ivf_id_map_.kCalled);
        EXPECT_EQ(expectedNprobe, actualNprobe);
        if (input.parentIdsPresent) {
            // NOTE(review): the call above always passes nullptr for the parent
            // ids - verify that a non-null grouper is really expected here.
            faiss::IDGrouper *grouper = ivf_id_map_.paramsCalled->grp;
            EXPECT_TRUE(grouper != nullptr);
        }
        ivf_id_map_.resetMock();
    }

    INSTANTIATE_TEST_CASE_P(
        QueryIndexIVFTest,
        FaissWrapperIVFQueryTestFixture,
        ::testing::Values(
            QueryIndexInput{"algoParams present, parent absent", 10, 0, false, false, -1, 200},
            QueryIndexInput{"algoParams absent, parent absent", 10, 0, false, false, -1, -1},
            QueryIndexInput{"algoParams present, parent present", 10, 0, true, true, -1, 200},
            QueryIndexInput{"algoParams absent, parent present", 10, 0, true, true, -1, -1}
        )
    );
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.IndexUtil;
import org.opensearch.knn.index.util.EngineSpecificMethodContext;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.VectorQueryType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.query.parser.MethodParametersParser;
import org.opensearch.knn.index.util.EngineSpecificMethodContext;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
Expand All @@ -49,6 +49,7 @@
import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES;
import static org.opensearch.knn.common.KNNConstants.MIN_SCORE;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue;
import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion;
Expand All @@ -72,6 +73,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
public static final ParseField MAX_DISTANCE_FIELD = new ParseField(MAX_DISTANCE);
public static final ParseField MIN_SCORE_FIELD = new ParseField(MIN_SCORE);
public static final ParseField EF_SEARCH_FIELD = new ParseField(METHOD_PARAMETER_EF_SEARCH);
public static final ParseField NPROBE_FIELD = new ParseField(METHOD_PARAMETER_NPROBES);
public static final ParseField METHOD_PARAMS_FIELD = new ParseField(METHOD_PARAMETER);
public static final int K_MAX = 10000;
/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
import static org.opensearch.knn.index.query.KNNQueryBuilder.METHOD_PARAMS_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.NAME;

/**
* Note: This parser is also used by the neural plugin, so breaking changes here will require matching changes in the neural plugin.
*/
@EqualsAndHashCode
@Getter
@AllArgsConstructor
Expand Down
Loading
Loading