Skip to content

Commit ef1ebb5

Browse files
committed
Added Deep Java Library example [skip ci]
1 parent 532832e commit ef1ebb5

File tree

3 files changed

+159
-0
lines changed

3 files changed

+159
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ Or check out some examples:
3737

3838
- [Embeddings](examples/openai/src/main/java/com/example/Example.java) with OpenAI
3939
- [Binary embeddings](examples/cohere/src/main/java/com/example/Example.java) with Cohere
40+
- [Sentence embeddings](examples/djl/src/main/java/com/example/Example.java) with Deep Java Library
4041
- [Horizontal scaling](examples/citus/src/main/java/com/example/Example.java) with Citus
4142
- [Bulk loading](examples/loading/src/main/java/com/example/Example.java) with `COPY`
4243

examples/djl/pom.xml

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
<?xml version='1.0' encoding='UTF-8'?>
2+
<project xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://maven.apache.org/POM/4.0.0">
3+
<modelVersion>4.0.0</modelVersion>
4+
<groupId>com.example</groupId>
5+
<artifactId>example</artifactId>
6+
<version>1</version>
7+
<properties>
8+
<maven.compiler.release>11</maven.compiler.release>
9+
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
10+
</properties>
11+
<dependencyManagement>
12+
<dependencies>
13+
<dependency>
14+
<groupId>ai.djl</groupId>
15+
<artifactId>bom</artifactId>
16+
<version>0.29.0</version>
17+
<type>pom</type>
18+
<scope>import</scope>
19+
</dependency>
20+
</dependencies>
21+
</dependencyManagement>
22+
<dependencies>
23+
<dependency>
24+
<groupId>org.postgresql</groupId>
25+
<artifactId>postgresql</artifactId>
26+
<version>42.7.3</version>
27+
</dependency>
28+
<dependency>
29+
<groupId>com.pgvector</groupId>
30+
<artifactId>pgvector</artifactId>
31+
<version>0.1.6</version>
32+
</dependency>
33+
<dependency>
34+
<groupId>ai.djl</groupId>
35+
<artifactId>api</artifactId>
36+
</dependency>
37+
<dependency>
38+
<groupId>ai.djl.huggingface</groupId>
39+
<artifactId>tokenizers</artifactId>
40+
</dependency>
41+
<dependency>
42+
<groupId>ai.djl.pytorch</groupId>
43+
<artifactId>pytorch-engine</artifactId>
44+
</dependency>
45+
<dependency>
46+
<groupId>org.slf4j</groupId>
47+
<artifactId>slf4j-nop</artifactId>
48+
<version>2.0.9</version>
49+
</dependency>
50+
</dependencies>
51+
<build>
52+
<plugins>
53+
<plugin>
54+
<artifactId>maven-assembly-plugin</artifactId>
55+
<version>3.7.1</version>
56+
<configuration>
57+
<descriptorRefs>
58+
<descriptorRef>jar-with-dependencies</descriptorRef>
59+
</descriptorRefs>
60+
<archive>
61+
<manifest>
62+
<mainClass>com.example.Example</mainClass>
63+
</manifest>
64+
</archive>
65+
<finalName>example</finalName>
66+
</configuration>
67+
<executions>
68+
<execution>
69+
<id>make-assembly</id>
70+
<phase>package</phase>
71+
<goals>
72+
<goal>single</goal>
73+
</goals>
74+
</execution>
75+
</executions>
76+
</plugin>
77+
</plugins>
78+
</build>
79+
</project>
examples/djl/src/main/java/com/example/Example.java

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package com.example;
2+
3+
import java.io.IOException;
4+
import java.sql.Connection;
5+
import java.sql.DriverManager;
6+
import java.sql.PreparedStatement;
7+
import java.sql.ResultSet;
8+
import java.sql.SQLException;
9+
import java.sql.Statement;
10+
import java.util.ArrayList;
11+
import java.util.List;
12+
import ai.djl.ModelException;
13+
import ai.djl.huggingface.translator.TextEmbeddingTranslatorFactory;
14+
import ai.djl.inference.Predictor;
15+
import ai.djl.repository.zoo.Criteria;
16+
import ai.djl.repository.zoo.ZooModel;
17+
import ai.djl.translate.TranslateException;
18+
import com.pgvector.PGvector;
19+
20+
public class Example {
21+
public static void main(String[] args) throws IOException, ModelException, SQLException, TranslateException {
22+
Connection conn = DriverManager.getConnection("jdbc:postgresql://localhost:5432/pgvector_example");
23+
24+
Statement setupStmt = conn.createStatement();
25+
setupStmt.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector");
26+
setupStmt.executeUpdate("DROP TABLE IF EXISTS documents");
27+
28+
PGvector.addVectorType(conn);
29+
30+
Statement createStmt = conn.createStatement();
31+
createStmt.executeUpdate("CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding vector(384))");
32+
33+
ZooModel<String, float[]> model = loadModel("sentence-transformers/all-MiniLM-L6-v2");
34+
35+
String[] input = {
36+
"The dog is barking",
37+
"The cat is purring",
38+
"The bear is growling"
39+
};
40+
List<float[]> embeddings = generateEmbeddings(model, input);
41+
42+
for (int i = 0; i < input.length; i++) {
43+
PreparedStatement insertStmt = conn.prepareStatement("INSERT INTO documents (content, embedding) VALUES (?, ?)");
44+
insertStmt.setString(1, input[i]);
45+
insertStmt.setObject(2, new PGvector(embeddings.get(i)));
46+
insertStmt.executeUpdate();
47+
}
48+
49+
long documentId = 2;
50+
PreparedStatement neighborStmt = conn.prepareStatement("SELECT * FROM documents WHERE id != ? ORDER BY embedding <=> (SELECT embedding FROM documents WHERE id = ?) LIMIT 5");
51+
neighborStmt.setLong(1, documentId);
52+
neighborStmt.setLong(2, documentId);
53+
ResultSet rs = neighborStmt.executeQuery();
54+
while (rs.next()) {
55+
System.out.println(rs.getString("content"));
56+
}
57+
58+
conn.close();
59+
}
60+
61+
private static ZooModel<String, float[]> loadModel(String id) throws IOException, ModelException {
62+
return Criteria.builder()
63+
.setTypes(String.class, float[].class)
64+
.optModelUrls("djl://ai.djl.huggingface.pytorch/" + id)
65+
.optEngine("PyTorch")
66+
.optTranslatorFactory(new TextEmbeddingTranslatorFactory())
67+
.build()
68+
.loadModel();
69+
}
70+
71+
private static List<float[]> generateEmbeddings(ZooModel<String, float[]> model, String[] input) throws TranslateException {
72+
Predictor<String, float[]> predictor = model.newPredictor();
73+
List<float[]> embeddings = new ArrayList<>(input.length);
74+
for (String text : input) {
75+
embeddings.add(predictor.predict(text));
76+
}
77+
return embeddings;
78+
}
79+
}

0 commit comments

Comments
 (0)