Skip to content

Commit

Permalink
big boi (langchain4j#25)
Browse files Browse the repository at this point in the history
Sorry for a huge PR...
- added retries to OpenAiChatModel
- added @UserName: an option to pass the user's name as a parameter
in the AI Services API
- added an option to split multiple documents at once (see
DocumentSplitter)
- redesigned document loaders (see FileSystemDocumentLoader)
- renamed DocumentSegment into TextSegment
- redesigned ConversationalRetrievalChain
- added EmbeddingStoreIngestor
- misc refactorings/fixes

---------

Co-authored-by: deep-learning-dynamo <deep-learning-dynamo@gmail.com>
  • Loading branch information
langchain4j and deep-learning-dynamo authored Jul 15, 2023
1 parent cbc4462 commit 755c9d0
Show file tree
Hide file tree
Showing 49 changed files with 869 additions and 411 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package dev.langchain4j.data.document;

import dev.langchain4j.data.segment.TextSegment;

import java.util.Objects;

public class Document {
Expand All @@ -20,8 +22,8 @@ public Metadata metadata() {
return metadata;
}

public DocumentSegment toDocumentSegment() {
return DocumentSegment.from(text, metadata);
public TextSegment toTextSegment() {
return TextSegment.from(text, metadata);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
package dev.langchain4j.data.document;

import dev.langchain4j.data.segment.TextSegment;

import java.util.List;

import static java.util.stream.Collectors.toList;

public interface DocumentSplitter {

List<DocumentSegment> split(Document document);
List<TextSegment> split(Document document);

/**
 * Splits each of the given documents with {@link #split(Document)} and
 * gathers all resulting segments into a single flat list.
 *
 * @param documents the documents to split
 * @return the segments produced for every document, collected into one list
 */
default List<TextSegment> split(List<Document> documents) {
    return documents.stream()
            .map(this::split)
            .flatMap(List::stream)
            .collect(toList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,39 @@

public class UserMessage extends ChatMessage {

private final String name;

/**
 * Creates a user message with the given text and no user name
 * (delegates to the two-argument constructor with a {@code null} name).
 *
 * @param text the message text
 */
public UserMessage(String text) {
this(null, text);
}

/**
 * Creates a user message with an optional user name.
 *
 * @param name the name of the user; {@code null} when no name is given
 * @param text the message text, passed to the superclass constructor
 */
public UserMessage(String name, String text) {
super(text);
this.name = name;
}

/**
 * Returns the user name stored in this message.
 *
 * @return the name supplied at construction; may be {@code null}
 */
public String name() {
return name;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
UserMessage that = (UserMessage) o;
return Objects.equals(this.text, that.text);
return Objects.equals(this.name, that.name)
&& Objects.equals(this.text, that.text);
}

@Override
public int hashCode() {
return Objects.hash(text);
return Objects.hash(name, text);
}

/**
 * Builds a one-line, human-readable rendering of this message that shows
 * the name and the text, each wrapped in double quotes.
 */
@Override
public String toString() {
    return "UserMessage { name = \"" + name + "\" text = \"" + text + "\" }";
}
Expand All @@ -32,7 +45,15 @@ public static UserMessage from(String text) {
return new UserMessage(text);
}

/**
 * Static factory equivalent to {@code new UserMessage(name, text)}.
 *
 * @param name the name of the user
 * @param text the message text
 * @return a new {@code UserMessage}
 */
public static UserMessage from(String name, String text) {
return new UserMessage(name, text);
}

/**
 * Alias for {@code from(text)}.
 *
 * @param text the message text
 * @return the message returned by {@code from(text)}
 */
public static UserMessage userMessage(String text) {
return from(text);
}

/**
 * Alias for {@code from(name, text)}.
 *
 * @param name the name of the user
 * @param text the message text
 * @return the message returned by {@code from(name, text)}
 */
public static UserMessage userMessage(String name, String text) {
return from(name, text);
}
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
package dev.langchain4j.data.document;
package dev.langchain4j.data.segment;


import dev.langchain4j.data.document.Metadata;

import java.util.Objects;

public class DocumentSegment {
/**
* Represents a semantically meaningful segment (chunk/piece/fragment) of a larger entity such as a document or chat conversation.
* This might be a sentence, a paragraph, or any other discrete unit of text that carries meaning.
* This class encapsulates a piece of text and its associated metadata.
*/
public class TextSegment {

private final String text;
private final Metadata metadata;

public DocumentSegment(String text, Metadata metadata) {
public TextSegment(String text, Metadata metadata) {
this.text = text;
this.metadata = metadata;
}
Expand All @@ -25,7 +32,7 @@ public Metadata metadata() {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
DocumentSegment that = (DocumentSegment) o;
TextSegment that = (TextSegment) o;
return Objects.equals(this.text, that.text)
&& Objects.equals(this.metadata, that.metadata);
}
Expand All @@ -37,25 +44,25 @@ public int hashCode() {

@Override
public String toString() {
return "DocumentSegment {" +
return "TextSegment {" +
" text = \"" + text + "\"" +
" metadata = \"" + metadata + "\"" +
" }";
}

public static DocumentSegment from(String text) {
return new DocumentSegment(text, new Metadata());
public static TextSegment from(String text) {
return new TextSegment(text, new Metadata());
}

public static DocumentSegment from(String text, Metadata metadata) {
return new DocumentSegment(text, metadata);
public static TextSegment from(String text, Metadata metadata) {
return new TextSegment(text, metadata);
}

public static DocumentSegment documentSegment(String text) {
public static TextSegment textSegment(String text) {
return from(text);
}

public static DocumentSegment documentSegment(String text, Metadata metadata) {
public static TextSegment textSegment(String text, Metadata metadata) {
return from(text, metadata);
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package dev.langchain4j.model.chat;

import dev.langchain4j.MightChangeInTheFuture;
import dev.langchain4j.data.document.DocumentSegment;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.input.Prompt;
Expand All @@ -22,5 +22,5 @@ public interface TokenCountEstimator {

int estimateTokenCount(List<ChatMessage> messages);

int estimateTokenCount(DocumentSegment documentSegment);
int estimateTokenCount(TextSegment textSegment);
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package dev.langchain4j.model.embedding;

import dev.langchain4j.data.document.DocumentSegment;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.output.Result;

Expand All @@ -10,7 +10,7 @@ public interface EmbeddingModel {

Result<Embedding> embed(String text);

Result<Embedding> embed(DocumentSegment documentSegment);
Result<Embedding> embed(TextSegment textSegment);

Result<List<Embedding>> embedAll(List<DocumentSegment> documentSegments);
Result<List<Embedding>> embedAll(List<TextSegment> textSegments);
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package dev.langchain4j.model.embedding;

import dev.langchain4j.data.document.DocumentSegment;
import dev.langchain4j.data.segment.TextSegment;

import java.util.List;

public interface TokenCountEstimator {

int estimateTokenCount(String text);

int estimateTokenCount(DocumentSegment documentSegment);
int estimateTokenCount(TextSegment textSegment);

int estimateTokenCount(List<DocumentSegment> documentSegments);
int estimateTokenCount(List<TextSegment> textSegments);
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package dev.langchain4j.model.language;

import dev.langchain4j.data.document.DocumentSegment;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.input.Prompt;

public interface TokenCountEstimator {
Expand All @@ -11,5 +11,5 @@ public interface TokenCountEstimator {

int estimateTokenCount(Object structuredPrompt);

int estimateTokenCount(DocumentSegment documentSegment);
int estimateTokenCount(TextSegment textSegment);
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package dev.langchain4j.model.moderation;

import dev.langchain4j.data.document.DocumentSegment;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.output.Result;
Expand All @@ -19,5 +19,5 @@ public interface ModerationModel {

Result<Moderation> moderate(List<ChatMessage> messages);

Result<Moderation> moderate(DocumentSegment documentSegment);
Result<Moderation> moderate(TextSegment textSegment);
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import dev.langchain4j.data.document.DocumentSegment;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import io.pinecone.PineconeClient;
import io.pinecone.PineconeClientConfig;
import io.pinecone.PineconeConnection;
Expand All @@ -19,10 +19,10 @@
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;

public class PineconeEmbeddingStoreImpl implements EmbeddingStore<DocumentSegment> {
public class PineconeEmbeddingStoreImpl implements EmbeddingStore<TextSegment> {

private static final String DEFAULT_NAMESPACE = "default"; // do not change, will break backward compatibility!
private static final String METADATA_DOCUMENT_SEGMENT_TEXT = "document_segment_text"; // do not change, will break backward compatibility!
private static final String METADATA_TEXT_SEGMENT = "text_segment"; // do not change, will break backward compatibility!

private final PineconeConnection connection;
private final String nameSpace;
Expand Down Expand Up @@ -57,9 +57,9 @@ public void add(String id, Embedding embedding) {
}

@Override
public String add(Embedding embedding, DocumentSegment documentSegment) {
public String add(Embedding embedding, TextSegment textSegment) {
String id = generateRandomId(embedding);
addInternal(id, embedding, documentSegment);
addInternal(id, embedding, textSegment);
return id;
}

Expand All @@ -76,22 +76,22 @@ public List<String> addAll(List<Embedding> embeddings) {
}

@Override
public List<String> addAll(List<Embedding> embeddings, List<DocumentSegment> documentSegments) {
public List<String> addAll(List<Embedding> embeddings, List<TextSegment> textSegments) {

List<String> ids = embeddings.stream()
.map(PineconeEmbeddingStoreImpl::generateRandomId)
.collect(toList());

addAllInternal(ids, embeddings, documentSegments);
addAllInternal(ids, embeddings, textSegments);

return ids;
}

private void addInternal(String id, Embedding embedding, DocumentSegment documentSegment) {
addAllInternal(singletonList(id), singletonList(embedding), documentSegment == null ? null : singletonList(documentSegment));
private void addInternal(String id, Embedding embedding, TextSegment textSegment) {
addAllInternal(singletonList(id), singletonList(embedding), textSegment == null ? null : singletonList(textSegment));
}

private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<DocumentSegment> documentSegments) {
private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> textSegments) {

UpsertRequest.Builder upsertRequestBuilder = UpsertRequest.newBuilder()
.setNamespace(nameSpace);
Expand All @@ -100,11 +100,11 @@ private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<D

String id = ids.get(i);
Embedding embedding = embeddings.get(i);
DocumentSegment documentSegment = documentSegments.get(i);
TextSegment textSegment = textSegments.get(i);

Struct vectorMetadata = Struct.newBuilder()
.putFields(METADATA_DOCUMENT_SEGMENT_TEXT, Value.newBuilder()
.setStringValue(documentSegment.text())
.putFields(METADATA_TEXT_SEGMENT, Value.newBuilder()
.setStringValue(textSegment.text())
.build())
.build();

Expand All @@ -121,12 +121,12 @@ private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<D
}

@Override
public List<EmbeddingMatch<DocumentSegment>> findRelevant(Embedding referenceEmbedding, int maxResults) {
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults) {
return findRelevant(referenceEmbedding, maxResults, -1); // TODO check -1
}

@Override
public List<EmbeddingMatch<DocumentSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minSimilarity) {
public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minSimilarity) {

QueryVector queryVector = QueryVector
.newBuilder()
Expand Down Expand Up @@ -167,24 +167,24 @@ public List<EmbeddingMatch<DocumentSegment>> findRelevant(Embedding referenceEmb
.collect(toList());
}

private static EmbeddingMatch<DocumentSegment> toEmbeddingMatch(Vector vector) {
Value documentSegmentTextValue = vector.getMetadata()
private static EmbeddingMatch<TextSegment> toEmbeddingMatch(Vector vector) {
Value textSegmentValue = vector.getMetadata()
.getFieldsMap()
.get(METADATA_DOCUMENT_SEGMENT_TEXT);
.get(METADATA_TEXT_SEGMENT);

return new EmbeddingMatch<>(
vector.getId(),
Embedding.from(vector.getValuesList()),
createDocumentSegmentIfExists(documentSegmentTextValue),
createTextSegmentIfExists(textSegmentValue),
null); // TODO
}

private static DocumentSegment createDocumentSegmentIfExists(Value documentSegmentTextValue) {
if (documentSegmentTextValue == null) {
private static TextSegment createTextSegmentIfExists(Value textSegmentValue) {
if (textSegmentValue == null) {
return null;
}

return DocumentSegment.from(documentSegmentTextValue.getStringValue());
return TextSegment.from(textSegmentValue.getStringValue());
}

private static String generateRandomId(Embedding embedding) {
Expand Down
Loading

0 comments on commit 755c9d0

Please sign in to comment.