-
Notifications
You must be signed in to change notification settings - Fork 5
/
Example.java
109 lines (95 loc) · 4.2 KB
/
Example.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
package com.example;
import java.io.IOException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;
import ai.djl.ModelException;
import ai.djl.huggingface.translator.TextEmbeddingTranslatorFactory;
import ai.djl.inference.Predictor;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import com.pgvector.PGvector;
/**
 * End-to-end hybrid search example: documents are embedded with a Hugging Face
 * sentence-transformers model loaded through DJL, stored in a pgvector {@code vector(384)}
 * column, and then ranked by fusing a vector-distance search with a Postgres full-text
 * search using Reciprocal Rank Fusion (RRF).
 *
 * <p>Assumes a reachable database at {@code jdbc:postgresql://localhost:5432/pgvector_example}.
 * The {@code documents} table is dropped and recreated on every run.
 */
public class Example {

    /**
     * Hybrid search query combining a semantic ranking (by the pgvector {@code <=>} distance to
     * the query embedding, closest first) and a keyword ranking (by {@code ts_rank_cd}, highest
     * first) with Reciprocal Rank Fusion: {@code score = 1/(k + semantic_rank) + 1/(k + keyword_rank)}.
     *
     * <p>Positional parameters:
     * <ol>
     *   <li>query embedding, used by the semantic {@code RANK()} window</li>
     *   <li>query embedding, used by the semantic {@code ORDER BY}</li>
     *   <li>query text, passed to {@code plainto_tsquery('english', ?)}</li>
     *   <li>RRF constant {@code k} for the semantic term</li>
     *   <li>RRF constant {@code k} for the keyword term</li>
     * </ol>
     *
     * <p>Each sub-search keeps its top 20 hits. A document found by only one side contributes
     * {@code 0.0} for the other side (via {@code COALESCE}). The result has columns {@code id}
     * and {@code score}, limited to the 5 highest scores.
     */
    public static final String HYBRID_SQL = """
            WITH semantic_search AS (
                SELECT id, RANK () OVER (ORDER BY embedding <=> ?) AS rank
                FROM documents
                ORDER BY embedding <=> ?
                LIMIT 20
            ),
            keyword_search AS (
                SELECT id, RANK () OVER (
                    ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC
                ) AS rank
                FROM documents, plainto_tsquery('english', ?) query
                WHERE to_tsvector('english', content) @@ query
                ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC
                LIMIT 20
            )
            SELECT
                COALESCE(semantic_search.id, keyword_search.id) AS id,
                COALESCE(1.0 / (? + semantic_search.rank), 0.0) +
                COALESCE(1.0 / (? + keyword_search.rank), 0.0) AS score
            FROM semantic_search
            FULL OUTER JOIN keyword_search ON semantic_search.id = keyword_search.id
            ORDER BY score DESC
            LIMIT 5
            """;

    /**
     * Runs the example: recreates the schema, embeds and inserts three sample documents, then
     * prints the top hybrid-search matches for a fixed query.
     *
     * <p>All JDBC and DJL resources are closed via try-with-resources, including when an
     * exception is thrown part-way through.
     *
     * @param args unused
     * @throws IOException        if the model cannot be loaded
     * @throws ModelException     if the model cannot be found or initialized
     * @throws SQLException       on any database error
     * @throws TranslateException if embedding generation fails
     */
    public static void main(String[] args) throws IOException, ModelException, SQLException, TranslateException {
        String[] input = {
            "The dog is barking",
            "The cat is purring",
            "The bear is growling"
        };
        String query = "growling bear";
        // RRF smoothing constant, bound to both positions 4 and 5 of HYBRID_SQL.
        double k = 60;

        try (Connection conn = DriverManager.getConnection("jdbc:postgresql://localhost:5432/pgvector_example")) {
            createSchema(conn);
            try (ZooModel<String, float[]> model = loadModel("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")) {
                insertDocuments(conn, input, generateEmbeddings(model, input));
                float[] queryEmbedding = generateEmbeddings(model, new String[] {query}).get(0);
                printHybridResults(conn, query, queryEmbedding, k);
            }
        }
    }

    /**
     * Enables the {@code vector} extension, registers the pgvector type on {@code conn}, and
     * (re)creates the {@code documents} table, dropping any existing one.
     */
    private static void createSchema(Connection conn) throws SQLException {
        try (Statement stmt = conn.createStatement()) {
            stmt.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector");
            stmt.executeUpdate("DROP TABLE IF EXISTS documents");
            // Register before any PGvector values are bound with setObject on this connection.
            PGvector.addVectorType(conn);
            // NOTE(review): vector(384) is assumed to match the model's embedding size; verify if the model changes.
            stmt.executeUpdate("CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding vector(384))");
        }
    }

    /**
     * Inserts one row per document, pairing {@code contents[i]} with {@code embeddings.get(i)}.
     * The INSERT is prepared once and re-executed for each row.
     */
    private static void insertDocuments(Connection conn, String[] contents, List<float[]> embeddings)
            throws SQLException {
        try (PreparedStatement insertStmt =
                conn.prepareStatement("INSERT INTO documents (content, embedding) VALUES (?, ?)")) {
            for (int i = 0; i < contents.length; i++) {
                insertStmt.setString(1, contents[i]);
                insertStmt.setObject(2, new PGvector(embeddings.get(i)));
                insertStmt.executeUpdate();
            }
        }
    }

    /**
     * Executes {@link #HYBRID_SQL} for the given query text and embedding, and prints each
     * result row's document id and RRF score to standard output.
     */
    private static void printHybridResults(Connection conn, String query, float[] queryEmbedding, double k)
            throws SQLException {
        PGvector queryVector = new PGvector(queryEmbedding);
        try (PreparedStatement queryStmt = conn.prepareStatement(HYBRID_SQL)) {
            queryStmt.setObject(1, queryVector);
            queryStmt.setObject(2, queryVector);
            queryStmt.setString(3, query);
            queryStmt.setDouble(4, k);
            queryStmt.setDouble(5, k);
            try (ResultSet rs = queryStmt.executeQuery()) {
                while (rs.next()) {
                    System.out.println(String.format("document: %d, RRF score: %f", rs.getLong("id"), rs.getDouble("score")));
                }
            }
        }
    }

    /**
     * Loads a Hugging Face text-embedding model through DJL's PyTorch engine.
     *
     * @param id Hugging Face model id, appended to {@code djl://ai.djl.huggingface.pytorch/}
     * @return the loaded model; the caller is responsible for closing it
     */
    private static ZooModel<String, float[]> loadModel(String id) throws IOException, ModelException {
        return Criteria.builder()
                .setTypes(String.class, float[].class)
                .optModelUrls("djl://ai.djl.huggingface.pytorch/" + id)
                .optEngine("PyTorch")
                .optTranslatorFactory(new TextEmbeddingTranslatorFactory())
                .build()
                .loadModel();
    }

    /**
     * Embeds each input string with {@code model}, returning embeddings in input order.
     * The predictor opened for this call is closed before returning.
     */
    private static List<float[]> generateEmbeddings(ZooModel<String, float[]> model, String[] input)
            throws TranslateException {
        try (Predictor<String, float[]> predictor = model.newPredictor()) {
            List<float[]> embeddings = new ArrayList<>(input.length);
            for (String text : input) {
                embeddings.add(predictor.predict(text));
            }
            return embeddings;
        }
    }
}