Skip to content

Patched re-ranking for JavaScript #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion korvus/javascript/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "korvus",
"version": "1.1.2",
"version": "1.1.3",
"description": "Open Source Alternative for Building End-to-End Vector Search Applications without OpenAI & Pinecone",
"keywords": [
"postgres",
Expand Down
40 changes: 40 additions & 0 deletions korvus/javascript/tests/typescript-tests/test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,46 @@ it("can vector search with query builder", async () => {
await collection.archive();
});

// Vector search over the `title` and `body` fields with a `rerank` stage,
// asserting the exact order of the returned document ids.
it("can vector search with re-ranking", async () => {
  // `title`: e5-small-v2 embeddings plus English full-text search.
  // `body`: recursive_character splitter, embeddings from the `openai` source.
  const rerankPipeline = korvus.newPipeline("1", {
    title: {
      semantic_search: {
        model: "intfloat/e5-small-v2",
        parameters: { prompt: "passage: " },
      },
      full_text_search: { configuration: "english" },
    },
    body: {
      splitter: { model: "recursive_character" },
      semantic_search: {
        model: "text-embedding-ada-002",
        source: "openai",
      },
    },
  });
  const rerankCollection = korvus.newCollection("test_j_c_cvswr_0");
  await rerankCollection.add_pipeline(rerankPipeline);
  await rerankCollection.upsert_documents(generate_dummy_documents(5));

  const hits = await rerankCollection.vector_search(
    {
      query: {
        fields: {
          title: {
            query: "Test document: 2",
            parameters: { prompt: "query: " },
            full_text_filter: "test",
          },
          body: { query: "Test document: 2" },
        },
        // Restrict candidates to documents whose id is greater than 2.
        filter: { id: { "$gt": 2 } },
      },
      rerank: {
        model: "mixedbread-ai/mxbai-rerank-base-v1",
        query: "Test query",
        num_documents_to_rerank: 100,
      },
      limit: 5,
    },
    rerankPipeline,
  );

  const returnedIds = hits.map((hit) => {
    return hit["document"]["id"];
  });
  expect(returnedIds).toEqual([4, 3, 3, 4]);

  // Archive the collection so the test leaves no active collection behind.
  await rerankCollection.archive();
});

///////////////////////////////////////////////////
// Test rag ///////////////////////////////////////
///////////////////////////////////////////////////
Expand Down
54 changes: 54 additions & 0 deletions korvus/python/tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,60 @@ async def test_can_vector_search_with_query_builder():
await collection.archive()


@pytest.mark.asyncio
async def test_can_vector_search_with_rerank():
    """Vector search over `title` and `text` with a `rerank` stage returns ids in the expected order."""
    # Both fields use the same e5-small-v2 embedding setup; `title` additionally
    # gets English full-text search and `text` is split before embedding.
    schema = {
        "title": {
            "semantic_search": {
                "model": "intfloat/e5-small-v2",
                "parameters": {"prompt": "passage: "},
            },
            "full_text_search": {"configuration": "english"},
        },
        "text": {
            "splitter": {"model": "recursive_character"},
            "semantic_search": {
                "model": "intfloat/e5-small-v2",
                "parameters": {"prompt": "passage: "},
            },
        },
    }
    rerank_pipeline = korvus.Pipeline("test_p_p_tcvswr_0", schema)
    # NOTE(review): the collection name uses the `tcvs` tag rather than `tcvswr`
    # like the pipeline above -- confirm it does not clash with another test.
    rerank_collection = korvus.Collection("test_p_c_tcvs_3")
    await rerank_collection.add_pipeline(rerank_pipeline)
    await rerank_collection.upsert_documents(generate_dummy_documents(5))

    search_request = {
        "query": {
            "fields": {
                "title": {
                    "query": "Test document: 2",
                    "parameters": {"prompt": "passage: "},
                    "full_text_filter": "test",
                },
                "text": {
                    "query": "Test document: 2",
                    "parameters": {"prompt": "passage: "},
                },
            },
            # Restrict candidates to documents whose id is greater than 2.
            "filter": {"id": {"$gt": 2}},
        },
        "rerank": {
            "model": "mixedbread-ai/mxbai-rerank-base-v1",
            "query": "Test query",
            "num_documents_to_rerank": 100,
        },
        "limit": 5,
    }
    hits = await rerank_collection.vector_search(search_request, rerank_pipeline)

    returned_ids = []
    for hit in hits:
        returned_ids.append(hit["document"]["id"])
    assert returned_ids == [3, 3, 4, 4]

    # Archive the collection so the test leaves no active collection behind.
    await rerank_collection.archive()


###################################################
## Test RAG #######################################
###################################################
Expand Down
2 changes: 1 addition & 1 deletion korvus/src/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ impl Collection {
let pool = get_or_initialize_pool(&self.database_url).await?;
let pipelines_table_name = format!("{}.pipelines", project_info.name);
let exists: bool = sqlx::query_scalar(&query_builder!(
"SELECT EXISTS (SELECT id FROM %s WHERE name = $1 AND active = TRUE)",
"SELECT EXISTS (SELECT id FROM %s WHERE name = $1)",
pipelines_table_name
))
.bind(&pipeline.name)
Expand Down
14 changes: 7 additions & 7 deletions korvus/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ mod tests {
#[tokio::test]
async fn can_upsert_documents_and_enable_pipeline() -> anyhow::Result<()> {
internal_init_logger(None, None).ok();
let collection_name = "test_r_c_cudaep_43";
let collection_name = "test_r_c_cudaep_44";
let mut collection = Collection::new(collection_name, None)?;
let pipeline_name = "0";
let mut pipeline = Pipeline::new(
Expand Down Expand Up @@ -654,7 +654,7 @@ mod tests {
#[tokio::test]
async fn random_pipelines_documents_test() -> anyhow::Result<()> {
internal_init_logger(None, None).ok();
let collection_name = "test_r_c_rpdt_3";
let collection_name = "test_r_c_rpdt_4";
let mut collection = Collection::new(collection_name, None)?;
let documents = generate_dummy_documents(6);
collection
Expand Down Expand Up @@ -818,7 +818,7 @@ mod tests {
#[tokio::test]
async fn pipeline_sync_status() -> anyhow::Result<()> {
internal_init_logger(None, None).ok();
let collection_name = "test_r_c_pss_6";
let collection_name = "test_r_c_pss_7";
let mut collection = Collection::new(collection_name, None)?;
let pipeline_name = "0";
let mut pipeline = Pipeline::new(
Expand Down Expand Up @@ -1140,7 +1140,7 @@ mod tests {
#[tokio::test]
async fn can_search_with_remote_embeddings() -> anyhow::Result<()> {
internal_init_logger(None, None).ok();
let collection_name = "test r_c_cswre_66";
let collection_name = "test r_c_cswre_67";
let mut collection = Collection::new(collection_name, None)?;
let documents = generate_dummy_documents(10);
collection.upsert_documents(documents.clone(), None).await?;
Expand Down Expand Up @@ -1314,7 +1314,7 @@ mod tests {
#[tokio::test]
async fn can_vector_search_with_remote_embeddings() -> anyhow::Result<()> {
internal_init_logger(None, None).ok();
let collection_name = "test r_c_cvswre_7";
let collection_name = "test r_c_cvswre_8";
let mut collection = Collection::new(collection_name, None)?;
let documents = generate_dummy_documents(10);
collection.upsert_documents(documents.clone(), None).await?;
Expand Down Expand Up @@ -1455,7 +1455,7 @@ mod tests {
async fn can_vector_search_with_local_embeddings_and_specify_document_keys(
) -> anyhow::Result<()> {
internal_init_logger(None, None).ok();
let collection_name = "test r_c_cvswleasdk_1";
let collection_name = "test r_c_cvswleasdk_2";
let mut collection = Collection::new(collection_name, None)?;
let documents = generate_dummy_documents(2);
collection.upsert_documents(documents.clone(), None).await?;
Expand Down Expand Up @@ -1556,7 +1556,7 @@ mod tests {
#[tokio::test]
async fn can_vector_search_with_local_embeddings_and_rerank() -> anyhow::Result<()> {
internal_init_logger(None, None).ok();
let collection_name = "test r_c_cvswlear_1";
let collection_name = "test r_c_cvswlear_2";
let mut collection = Collection::new(collection_name, None)?;
let documents = generate_dummy_documents(10);
collection.upsert_documents(documents.clone(), None).await?;
Expand Down
4 changes: 3 additions & 1 deletion korvus/src/vector_search_query_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,14 @@ const fn default_num_documents_to_rerank() -> u64 {
10
}

/// Validated `rerank` section of a vector search request.
///
/// Keys other than the fields declared here are rejected during
/// deserialization because of `deny_unknown_fields`.
#[serde_as]
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(deny_unknown_fields)]
struct ValidRerank {
// Query text used for reranking.
query: String,
// Name of the rerank model.
model: String,
// Number of candidate documents to rerank; when the key is absent the value
// comes from `default_num_documents_to_rerank()`.
#[serde(default = "default_num_documents_to_rerank")]
// NOTE(review): presumably routed through `CustomU64Convertor` because numbers
// arriving from JavaScript are f64 -- confirm against the convertor's definition.
#[serde_as(as = "FromInto<CustomU64Convertor>")]
num_documents_to_rerank: u64,
// Optional extra JSON parameters.
// NOTE(review): presumably forwarded to the rerank model -- confirm at the use site.
parameters: Option<Json>,
}
Expand All @@ -61,7 +63,7 @@ const fn default_limit() -> u64 {

#[serde_as]
#[derive(Debug, Deserialize, Serialize, Clone)]
// #[serde(deny_unknown_fields)]
#[serde(deny_unknown_fields)]
pub struct ValidQuery {
query: ValidQueryActions,
// Need this when coming from JavaScript as everything is an f64 from JS
Expand Down