Skip to content

Commit ecd432e

Browse files
authored
Merge pull request #22 from schemacrawler/full-text
Full text search for models that do not support embedding
2 parents 501a2c8 + ba4dd94 commit ecd432e

File tree

8 files changed

+106
-26
lines changed

8 files changed

+106
-26
lines changed

pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@
8181
<artifactId>jackson-module-jsonSchema</artifactId>
8282
</dependency>
8383

84+
<dependency>
85+
<groupId>us.fatehi</groupId>
86+
<artifactId>full-text-search</artifactId>
87+
<version>0.0.6</version>
88+
</dependency>
89+
8490
<dependency>
8591
<groupId>org.apache.commons</groupId>
8692
<artifactId>commons-math3</artifactId>

src/main/assembly/assembly.xml

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,20 @@
1818
<outputDirectory>.</outputDirectory>
1919
<useProjectArtifact>false</useProjectArtifact>
2020
<unpack>false</unpack>
21+
<useTransitiveFiltering>false</useTransitiveFiltering>
2122
<scope>runtime</scope>
2223
<excludes>
2324
<exclude>org.slf4j:*</exclude>
24-
<exclude>commons-dbutils:*</exclude>
25-
<exclude>com.fasterxml.jackson.core:*</exclude>
25+
<exclude>org.junit.platform:*</exclude>
26+
<exclude>org.junit.jupiter:*</exclude>
27+
<exclude>org.opentest4j:*</exclude>
28+
<exclude>com.fasterxml.jackson.*:*</exclude>
29+
<exclude>org.yaml:*</exclude>
30+
<exclude>javax.validation:*</exclude>
31+
<exclude>com.azure:azure-core-test</exclude>
32+
<exclude>org.apache.ant:*</exclude>
2633
<exclude>us.fatehi:*</exclude>
2734
</excludes>
28-
</dependencySet>
35+
</dependencySet>
2936
</dependencySets>
3037
</assembly>

src/main/java/schemacrawler/tools/command/aichat/utility/lanchain4j/AiModelFactoryUtility.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ public class AiModelFactoryUtility {
4040

4141
interface AiModelFactory {
4242

43+
boolean hasEmbeddingModel();
44+
4345
boolean isSupported();
4446

4547
ChatLanguageModel newChatLanguageModel();

src/main/java/schemacrawler/tools/command/aichat/utility/lanchain4j/AnthropicModelFactory.java

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,14 @@
2929
package schemacrawler.tools.command.aichat.utility.lanchain4j;
3030

3131
import java.time.Duration;
32-
import java.util.Arrays;
33-
import java.util.Collections;
34-
import java.util.List;
3532
import static java.util.Objects.requireNonNull;
36-
import dev.langchain4j.data.embedding.Embedding;
37-
import dev.langchain4j.data.segment.TextSegment;
33+
import dev.langchain4j.exception.UnsupportedFeatureException;
3834
import dev.langchain4j.memory.ChatMemory;
3935
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
4036
import dev.langchain4j.model.anthropic.AnthropicChatModel;
4137
import dev.langchain4j.model.anthropic.AnthropicChatModelName;
4238
import dev.langchain4j.model.chat.ChatLanguageModel;
4339
import dev.langchain4j.model.embedding.EmbeddingModel;
44-
import dev.langchain4j.model.output.Response;
4540
import schemacrawler.tools.command.aichat.options.AiChatCommandOptions;
4641
import schemacrawler.tools.command.aichat.utility.lanchain4j.AiModelFactoryUtility.AiModelFactory;
4742
import us.fatehi.utility.property.PropertyName;
@@ -55,6 +50,11 @@ public AnthropicModelFactory(final AiChatCommandOptions commandOptions) {
5550
aiChatCommandOptions = requireNonNull(commandOptions, "No AI Chat options provided");
5651
}
5752

53+
@Override
54+
public boolean hasEmbeddingModel() {
55+
return false;
56+
}
57+
5858
@Override
5959
public boolean isSupported() {
6060
if (!aiChatCommandOptions.aiProvider().equals(aiProvider.getName())) {
@@ -88,20 +88,7 @@ public ChatMemory newChatMemory() {
8888

8989
@Override
9090
public EmbeddingModel newEmbeddingModel() {
91-
return new EmbeddingModel() {
92-
93-
@Override
94-
public Response<List<Embedding>> embedAll(final List<TextSegment> textSegments) {
95-
if (textSegments == null || textSegments.isEmpty()) {
96-
return new Response<>(Collections.emptyList());
97-
}
98-
99-
final Embedding embedding = new Embedding(new float[] {0f});
100-
final Embedding[] embeddings = new Embedding[textSegments.size()];
101-
Arrays.fill(embeddings, embedding);
102-
return new Response<>(Arrays.asList(embeddings));
103-
}
104-
};
91+
throw new UnsupportedFeatureException("Anthropic does not have embedding models");
10592
}
10693

10794
@Override
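
src/main/java/schemacrawler/tools/command/aichat/utility/lanchain4j/FullTextCatalogContentRetriever.java

Lines changed: 64 additions & 0 deletions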
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package schemacrawler.tools.command.aichat.utility.lanchain4j;
2+
3+
import java.util.List;
4+
import java.util.logging.Logger;
5+
import org.apache.lucene.store.Directory;
6+
import static java.util.Objects.requireNonNull;
7+
import dev.langchain4j.data.document.Metadata;
8+
import dev.langchain4j.data.segment.TextSegment;
9+
import dev.langchain4j.rag.content.Content;
10+
import dev.langchain4j.rag.content.retriever.ContentRetriever;
11+
import dev.langchain4j.rag.query.Query;
12+
import schemacrawler.schema.Catalog;
13+
import schemacrawler.schema.DatabaseInfo;
14+
import schemacrawler.schema.Table;
15+
import schemacrawler.tools.command.serialize.model.CompactCatalogUtility;
16+
import schemacrawler.tools.command.serialize.model.TableDocument;
17+
import us.fatehi.search.DirectoryFactory;
18+
import us.fatehi.search.LuceneContentRetriever;
19+
import us.fatehi.search.LuceneIndexer;
20+
21+
public class FullTextCatalogContentRetriever implements ContentRetriever {
22+
23+
private static final Logger LOGGER =
24+
Logger.getLogger(Langchain4JChatAssistant.class.getCanonicalName());
25+
26+
private final LuceneContentRetriever fullTextCatalogRetriever;
27+
28+
public FullTextCatalogContentRetriever(final Catalog catalog) {
29+
requireNonNull(catalog, "No catalog provided");
30+
31+
final Directory tempDirectory = DirectoryFactory.tempDirectory();
32+
final LuceneIndexer luceneIndexer = new LuceneIndexer(tempDirectory);
33+
final TextSegment databaseInfoContent = getDatabaseInfoContent(catalog);
34+
luceneIndexer.addContent(databaseInfoContent);
35+
for (final Table table : catalog.getTables()) {
36+
final TableDocument tableDocument = CompactCatalogUtility.getTableDocument(table, false);
37+
luceneIndexer.addContent(TextSegment.from(tableDocument.toJson()));
38+
}
39+
40+
fullTextCatalogRetriever =
41+
LuceneContentRetriever.builder()
42+
.directory(tempDirectory)
43+
.matchUntilTopN()
44+
.maxTokens(5_000)
45+
.build();
46+
}
47+
48+
@Override
49+
public List<Content> retrieve(final Query query) {
50+
return fullTextCatalogRetriever.retrieve(query);
51+
}
52+
53+
private TextSegment getDatabaseInfoContent(final Catalog catalog) {
54+
final DatabaseInfo databaseInfo = catalog.getDatabaseInfo();
55+
final String databaseProductName = databaseInfo.getDatabaseProductName();
56+
final Metadata metadata = new Metadata();
57+
metadata.put("database", databaseProductName);
58+
metadata.put("database-version", databaseInfo.getDatabaseProductVersion());
59+
final TextSegment textSegment =
60+
TextSegment.from(
61+
String.format("Customize SQL queries for %s", databaseProductName), metadata);
62+
return textSegment;
63+
}
64+
}

src/main/java/schemacrawler/tools/command/aichat/utility/lanchain4j/GitHubModelFactory.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ public GitHubModelFactory(final AiChatCommandOptions commandOptions) {
5252
aiChatCommandOptions = requireNonNull(commandOptions, "No AI Chat options provided");
5353
}
5454

55+
@Override
56+
public boolean hasEmbeddingModel() {
57+
return true;
58+
}
59+
5560
@Override
5661
public boolean isSupported() {
5762
if (!aiChatCommandOptions.aiProvider().equals(aiProvider.getName())) {

src/main/java/schemacrawler/tools/command/aichat/utility/lanchain4j/Langchain4JChatAssistant.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,12 @@ public Langchain4JChatAssistant(
9797

9898
final boolean useMetadata = aiChatOptions.useMetadata();
9999
if (useMetadata) {
100-
final EmbeddingModel embeddingModel = modelFactory.newEmbeddingModel();
101-
contentRetriever = new CatalogContentRetriever(embeddingModel, catalog);
100+
if (modelFactory.hasEmbeddingModel()) {
101+
final EmbeddingModel embeddingModel = modelFactory.newEmbeddingModel();
102+
contentRetriever = new CatalogContentRetriever(embeddingModel, catalog);
103+
} else {
104+
contentRetriever = new FullTextCatalogContentRetriever(catalog);
105+
}
102106
} else {
103107
contentRetriever = query -> Collections.emptyList();
104108
}
@@ -171,7 +175,7 @@ private SystemMessage createSystemMessage(final String prompt) {
171175
buffer.append(metadataPriming).append("\n");
172176
final List<Content> contents = contentRetriever.retrieve(Query.from(prompt));
173177
for (final Content content : contents) {
174-
buffer.append(content.textSegment().text()).append("\n");
178+
buffer.append("\n").append(content.textSegment().text());
175179
}
176180

177181
return SystemMessage.from(buffer.toString());

src/main/java/schemacrawler/tools/command/aichat/utility/lanchain4j/OpenAIModelFactory.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ public OpenAIModelFactory(final AiChatCommandOptions commandOptions) {
5353
aiChatCommandOptions = requireNonNull(commandOptions, "No AI Chat options provided");
5454
}
5555

56+
@Override
57+
public boolean hasEmbeddingModel() {
58+
return true;
59+
}
60+
5661
@Override
5762
public boolean isSupported() {
5863
if (!aiChatCommandOptions.aiProvider().equals(aiProvider.getName())) {

0 commit comments

Comments
 (0)