Skip to content

Commit 47d9cdc

Browse files
committed
Added hybrid search example [skip ci]
1 parent ef1ebb5 commit 47d9cdc

File tree

3 files changed

+189
-0
lines changed

3 files changed

+189
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ Or check out some examples:
3838
- [Embeddings](examples/openai/src/main/java/com/example/Example.java) with OpenAI
3939
- [Binary embeddings](examples/cohere/src/main/java/com/example/Example.java) with Cohere
4040
- [Sentence embeddings](examples/djl/src/main/java/com/example/Example.java) with Deep Java Library
41+
- [Hybrid search](examples/hybrid/src/main/java/com/example/Example.java) with Deep Java Library (Reciprocal Rank Fusion)
4142
- [Horizontal scaling](examples/citus/src/main/java/com/example/Example.java) with Citus
4243
- [Bulk loading](examples/loading/src/main/java/com/example/Example.java) with `COPY`
4344

examples/hybrid/pom.xml

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
<?xml version='1.0' encoding='UTF-8'?>
2+
<project xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://maven.apache.org/POM/4.0.0">
3+
<modelVersion>4.0.0</modelVersion>
4+
<groupId>com.example</groupId>
5+
<artifactId>example</artifactId>
6+
<version>1</version>
7+
<properties>
8+
<maven.compiler.release>17</maven.compiler.release>
9+
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
10+
</properties>
11+
<dependencyManagement>
12+
<dependencies>
13+
<dependency>
14+
<groupId>ai.djl</groupId>
15+
<artifactId>bom</artifactId>
16+
<version>0.29.0</version>
17+
<type>pom</type>
18+
<scope>import</scope>
19+
</dependency>
20+
</dependencies>
21+
</dependencyManagement>
22+
<dependencies>
23+
<dependency>
24+
<groupId>org.postgresql</groupId>
25+
<artifactId>postgresql</artifactId>
26+
<version>42.7.3</version>
27+
</dependency>
28+
<dependency>
29+
<groupId>com.pgvector</groupId>
30+
<artifactId>pgvector</artifactId>
31+
<version>0.1.6</version>
32+
</dependency>
33+
<dependency>
34+
<groupId>ai.djl</groupId>
35+
<artifactId>api</artifactId>
36+
</dependency>
37+
<dependency>
38+
<groupId>ai.djl.huggingface</groupId>
39+
<artifactId>tokenizers</artifactId>
40+
</dependency>
41+
<dependency>
42+
<groupId>ai.djl.pytorch</groupId>
43+
<artifactId>pytorch-engine</artifactId>
44+
</dependency>
45+
<dependency>
46+
<groupId>org.slf4j</groupId>
47+
<artifactId>slf4j-nop</artifactId>
48+
<version>2.0.9</version>
49+
</dependency>
50+
</dependencies>
51+
<build>
52+
<plugins>
53+
<plugin>
54+
<artifactId>maven-assembly-plugin</artifactId>
55+
<version>3.7.1</version>
56+
<configuration>
57+
<descriptorRefs>
58+
<descriptorRef>jar-with-dependencies</descriptorRef>
59+
</descriptorRefs>
60+
<archive>
61+
<manifest>
62+
<mainClass>com.example.Example</mainClass>
63+
</manifest>
64+
</archive>
65+
<finalName>example</finalName>
66+
</configuration>
67+
<executions>
68+
<execution>
69+
<id>make-assembly</id>
70+
<phase>package</phase>
71+
<goals>
72+
<goal>single</goal>
73+
</goals>
74+
</execution>
75+
</executions>
76+
</plugin>
77+
</plugins>
78+
</build>
79+
</project>
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
package com.example;
2+
3+
import java.io.IOException;
4+
import java.sql.Connection;
5+
import java.sql.DriverManager;
6+
import java.sql.PreparedStatement;
7+
import java.sql.ResultSet;
8+
import java.sql.SQLException;
9+
import java.sql.Statement;
10+
import java.util.ArrayList;
11+
import java.util.List;
12+
import ai.djl.ModelException;
13+
import ai.djl.huggingface.translator.TextEmbeddingTranslatorFactory;
14+
import ai.djl.inference.Predictor;
15+
import ai.djl.repository.zoo.Criteria;
16+
import ai.djl.repository.zoo.ZooModel;
17+
import ai.djl.translate.TranslateException;
18+
import com.pgvector.PGvector;
19+
20+
public class Example {
21+
public static void main(String[] args) throws IOException, ModelException, SQLException, TranslateException {
22+
Connection conn = DriverManager.getConnection("jdbc:postgresql://localhost:5432/pgvector_example");
23+
24+
Statement setupStmt = conn.createStatement();
25+
setupStmt.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector");
26+
setupStmt.executeUpdate("DROP TABLE IF EXISTS documents");
27+
28+
PGvector.addVectorType(conn);
29+
30+
Statement createStmt = conn.createStatement();
31+
createStmt.executeUpdate("CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding vector(384))");
32+
33+
ZooModel<String, float[]> model = loadModel("sentence-transformers/multi-qa-MiniLM-L6-cos-v1");
34+
35+
String[] input = {
36+
"The dog is barking",
37+
"The cat is purring",
38+
"The bear is growling"
39+
};
40+
List<float[]> embeddings = generateEmbeddings(model, input);
41+
42+
for (int i = 0; i < input.length; i++) {
43+
PreparedStatement insertStmt = conn.prepareStatement("INSERT INTO documents (content, embedding) VALUES (?, ?)");
44+
insertStmt.setString(1, input[i]);
45+
insertStmt.setObject(2, new PGvector(embeddings.get(i)));
46+
insertStmt.executeUpdate();
47+
}
48+
49+
String query = "growling bear";
50+
float[] queryEmbedding = generateEmbeddings(model, new String[] {query}).get(0);
51+
double k = 60;
52+
53+
PreparedStatement queryStmt = conn.prepareStatement(HYBRID_SQL);
54+
queryStmt.setObject(1, new PGvector(queryEmbedding));
55+
queryStmt.setObject(2, new PGvector(queryEmbedding));
56+
queryStmt.setString(3, query);
57+
queryStmt.setDouble(4, k);
58+
queryStmt.setDouble(5, k);
59+
ResultSet rs = queryStmt.executeQuery();
60+
while (rs.next()) {
61+
System.out.println(String.format("document: %d, RRF score: %f", rs.getLong("id"), rs.getDouble("score")));
62+
}
63+
64+
conn.close();
65+
}
66+
67+
private static ZooModel<String, float[]> loadModel(String id) throws IOException, ModelException {
68+
return Criteria.builder()
69+
.setTypes(String.class, float[].class)
70+
.optModelUrls("djl://ai.djl.huggingface.pytorch/" + id)
71+
.optEngine("PyTorch")
72+
.optTranslatorFactory(new TextEmbeddingTranslatorFactory())
73+
.build()
74+
.loadModel();
75+
}
76+
77+
private static List<float[]> generateEmbeddings(ZooModel<String, float[]> model, String[] input) throws TranslateException {
78+
Predictor<String, float[]> predictor = model.newPredictor();
79+
List<float[]> embeddings = new ArrayList<>(input.length);
80+
for (String text : input) {
81+
embeddings.add(predictor.predict(text));
82+
}
83+
return embeddings;
84+
}
85+
86+
public static final String HYBRID_SQL = """
87+
WITH semantic_search AS (
88+
SELECT id, RANK () OVER (ORDER BY embedding <=> ?) AS rank
89+
FROM documents
90+
ORDER BY embedding <=> ?
91+
LIMIT 20
92+
),
93+
keyword_search AS (
94+
SELECT id, RANK () OVER (ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC)
95+
FROM documents, plainto_tsquery('english', ?) query
96+
WHERE to_tsvector('english', content) @@ query
97+
ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC
98+
LIMIT 20
99+
)
100+
SELECT
101+
COALESCE(semantic_search.id, keyword_search.id) AS id,
102+
COALESCE(1.0 / (? + semantic_search.rank), 0.0) +
103+
COALESCE(1.0 / (? + keyword_search.rank), 0.0) AS score
104+
FROM semantic_search
105+
FULL OUTER JOIN keyword_search ON semantic_search.id = keyword_search.id
106+
ORDER BY score DESC
107+
LIMIT 5
108+
""";
109+
}

0 commit comments

Comments
 (0)