Skip to content

Commit acfed87

Browse files
authored
Merge pull request #5 from postgresml/silas-patch-js-rerank
Patched re-ranking for JavaScript
2 parents 17a2f7f + 616ead8 commit acfed87

File tree

6 files changed

+106
-10
lines changed

6 files changed

+106
-10
lines changed

korvus/javascript/package.json

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "korvus",
3-
"version": "1.1.2",
3+
"version": "1.1.3",
44
"description": "Open Source Alternative for Building End-to-End Vector Search Applications without OpenAI & Pinecone",
55
"keywords": [
66
"postgres",

korvus/javascript/tests/typescript-tests/test.ts

Lines changed: 40 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -164,6 +164,46 @@ it("can vector search with query builder", async () => {
164164
await collection.archive();
165165
});
166166

167+
it("can vector search with re-ranking", async () => {
168+
let pipeline = korvus.newPipeline("1", {
169+
title: {
170+
semantic_search: { model: "intfloat/e5-small-v2", parameters: { prompt: "passage: " } },
171+
full_text_search: { configuration: "english" },
172+
},
173+
body: {
174+
splitter: { model: "recursive_character" },
175+
semantic_search: {
176+
model: "text-embedding-ada-002",
177+
source: "openai",
178+
},
179+
},
180+
});
181+
let collection = korvus.newCollection("test_j_c_cvswr_0")
182+
await collection.add_pipeline(pipeline)
183+
await collection.upsert_documents(generate_dummy_documents(5))
184+
let results = await collection.vector_search(
185+
{
186+
query: {
187+
fields: {
188+
title: { query: "Test document: 2", parameters: { prompt: "query: " }, full_text_filter: "test" },
189+
body: { query: "Test document: 2" },
190+
},
191+
filter: { id: { "$gt": 2 } },
192+
},
193+
rerank: {
194+
model: "mixedbread-ai/mxbai-rerank-base-v1",
195+
query: "Test query",
196+
num_documents_to_rerank: 100
197+
},
198+
limit: 5,
199+
},
200+
pipeline,
201+
);
202+
let ids = results.map(r => r["document"]["id"]);
203+
expect(ids).toEqual([4, 3, 3, 4]);
204+
await collection.archive();
205+
});
206+
167207
///////////////////////////////////////////////////
168208
// Test rag ///////////////////////////////////////
169209
///////////////////////////////////////////////////

korvus/python/tests/test.py

Lines changed: 54 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -212,6 +212,60 @@ async def test_can_vector_search_with_query_builder():
212212
await collection.archive()
213213

214214

215+
@pytest.mark.asyncio
216+
async def test_can_vector_search_with_rerank():
217+
pipeline = korvus.Pipeline(
218+
"test_p_p_tcvswr_0",
219+
{
220+
"title": {
221+
"semantic_search": {
222+
"model": "intfloat/e5-small-v2",
223+
"parameters": {"prompt": "passage: "},
224+
},
225+
"full_text_search": {"configuration": "english"},
226+
},
227+
"text": {
228+
"splitter": {"model": "recursive_character"},
229+
"semantic_search": {
230+
"model": "intfloat/e5-small-v2",
231+
"parameters": {"prompt": "passage: "},
232+
},
233+
},
234+
},
235+
)
236+
collection = korvus.Collection("test_p_c_tcvs_3")
237+
await collection.add_pipeline(pipeline)
238+
await collection.upsert_documents(generate_dummy_documents(5))
239+
results = await collection.vector_search(
240+
{
241+
"query": {
242+
"fields": {
243+
"title": {
244+
"query": "Test document: 2",
245+
"parameters": {"prompt": "passage: "},
246+
"full_text_filter": "test",
247+
},
248+
"text": {
249+
"query": "Test document: 2",
250+
"parameters": {"prompt": "passage: "},
251+
},
252+
},
253+
"filter": {"id": {"$gt": 2}},
254+
},
255+
"rerank": {
256+
"model": "mixedbread-ai/mxbai-rerank-base-v1",
257+
"query": "Test query",
258+
"num_documents_to_rerank": 100,
259+
},
260+
"limit": 5,
261+
},
262+
pipeline,
263+
)
264+
ids = [result["document"]["id"] for result in results]
265+
assert ids == [3, 3, 4, 4]
266+
await collection.archive()
267+
268+
215269
###################################################
216270
## Test RAG #######################################
217271
###################################################

korvus/src/collection.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -345,7 +345,7 @@ impl Collection {
345345
let pool = get_or_initialize_pool(&self.database_url).await?;
346346
let pipelines_table_name = format!("{}.pipelines", project_info.name);
347347
let exists: bool = sqlx::query_scalar(&query_builder!(
348-
"SELECT EXISTS (SELECT id FROM %s WHERE name = $1 AND active = TRUE)",
348+
"SELECT EXISTS (SELECT id FROM %s WHERE name = $1)",
349349
pipelines_table_name
350350
))
351351
.bind(&pipeline.name)

korvus/src/lib.rs

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -610,7 +610,7 @@ mod tests {
610610
#[tokio::test]
611611
async fn can_upsert_documents_and_enable_pipeline() -> anyhow::Result<()> {
612612
internal_init_logger(None, None).ok();
613-
let collection_name = "test_r_c_cudaep_43";
613+
let collection_name = "test_r_c_cudaep_44";
614614
let mut collection = Collection::new(collection_name, None)?;
615615
let pipeline_name = "0";
616616
let mut pipeline = Pipeline::new(
@@ -654,7 +654,7 @@ mod tests {
654654
#[tokio::test]
655655
async fn random_pipelines_documents_test() -> anyhow::Result<()> {
656656
internal_init_logger(None, None).ok();
657-
let collection_name = "test_r_c_rpdt_3";
657+
let collection_name = "test_r_c_rpdt_4";
658658
let mut collection = Collection::new(collection_name, None)?;
659659
let documents = generate_dummy_documents(6);
660660
collection
@@ -818,7 +818,7 @@ mod tests {
818818
#[tokio::test]
819819
async fn pipeline_sync_status() -> anyhow::Result<()> {
820820
internal_init_logger(None, None).ok();
821-
let collection_name = "test_r_c_pss_6";
821+
let collection_name = "test_r_c_pss_7";
822822
let mut collection = Collection::new(collection_name, None)?;
823823
let pipeline_name = "0";
824824
let mut pipeline = Pipeline::new(
@@ -1140,7 +1140,7 @@ mod tests {
11401140
#[tokio::test]
11411141
async fn can_search_with_remote_embeddings() -> anyhow::Result<()> {
11421142
internal_init_logger(None, None).ok();
1143-
let collection_name = "test r_c_cswre_66";
1143+
let collection_name = "test r_c_cswre_67";
11441144
let mut collection = Collection::new(collection_name, None)?;
11451145
let documents = generate_dummy_documents(10);
11461146
collection.upsert_documents(documents.clone(), None).await?;
@@ -1314,7 +1314,7 @@ mod tests {
13141314
#[tokio::test]
13151315
async fn can_vector_search_with_remote_embeddings() -> anyhow::Result<()> {
13161316
internal_init_logger(None, None).ok();
1317-
let collection_name = "test r_c_cvswre_7";
1317+
let collection_name = "test r_c_cvswre_8";
13181318
let mut collection = Collection::new(collection_name, None)?;
13191319
let documents = generate_dummy_documents(10);
13201320
collection.upsert_documents(documents.clone(), None).await?;
@@ -1455,7 +1455,7 @@ mod tests {
14551455
async fn can_vector_search_with_local_embeddings_and_specify_document_keys(
14561456
) -> anyhow::Result<()> {
14571457
internal_init_logger(None, None).ok();
1458-
let collection_name = "test r_c_cvswleasdk_1";
1458+
let collection_name = "test r_c_cvswleasdk_2";
14591459
let mut collection = Collection::new(collection_name, None)?;
14601460
let documents = generate_dummy_documents(2);
14611461
collection.upsert_documents(documents.clone(), None).await?;
@@ -1556,7 +1556,7 @@ mod tests {
15561556
#[tokio::test]
15571557
async fn can_vector_search_with_local_embeddings_and_rerank() -> anyhow::Result<()> {
15581558
internal_init_logger(None, None).ok();
1559-
let collection_name = "test r_c_cvswlear_1";
1559+
let collection_name = "test r_c_cvswlear_2";
15601560
let mut collection = Collection::new(collection_name, None)?;
15611561
let documents = generate_dummy_documents(10);
15621562
collection.upsert_documents(documents.clone(), None).await?;

korvus/src/vector_search_query_builder.rs

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,12 +45,14 @@ const fn default_num_documents_to_rerank() -> u64 {
4545
10
4646
}
4747

48+
#[serde_as]
4849
#[derive(Debug, Deserialize, Serialize, Clone)]
4950
#[serde(deny_unknown_fields)]
5051
struct ValidRerank {
5152
query: String,
5253
model: String,
5354
#[serde(default = "default_num_documents_to_rerank")]
55+
#[serde_as(as = "FromInto<CustomU64Convertor>")]
5456
num_documents_to_rerank: u64,
5557
parameters: Option<Json>,
5658
}
@@ -61,7 +63,7 @@ const fn default_limit() -> u64 {
6163

6264
#[serde_as]
6365
#[derive(Debug, Deserialize, Serialize, Clone)]
64-
// #[serde(deny_unknown_fields)]
66+
#[serde(deny_unknown_fields)]
6567
pub struct ValidQuery {
6668
query: ValidQueryActions,
6769
// Need this when coming from JavaScript as everything is an f64 from JS

0 commit comments

Comments
 (0)