Foundation for advanced RAG (langchain4j#538)
Until now, LangChain4j had only a simple (a.k.a. naive) RAG
implementation: a single `Retriever` was invoked on each interaction
with the LLM, and all retrieved `TextSegment`s were appended to the end
of the `UserMessage`. This approach was very limiting.

This PR introduces support for much more advanced RAG use cases. The
design and mental model are inspired by [this
article](https://blog.langchain.dev/deconstructing-rag/) and [this
paper](https://arxiv.org/abs/2312.10997); reading the article first is
recommended.

This PR introduces a `RetrievalAugmentor` interface responsible for
augmenting a `UserMessage` with relevant content before sending it to
the LLM. The `RetrievalAugmentor` can be used with both `AiServices` and
`ConversationalRetrievalChain`, as well as stand-alone.

A default implementation of `RetrievalAugmentor`
(`DefaultRetrievalAugmentor`) is provided with the library and is
suggested as a good starting point. However, users are not limited to it
and can plug in their own custom implementations.

`DefaultRetrievalAugmentor` decomposes the entire RAG flow into more
granular steps and base components:
- `QueryTransformer`
- `QueryRouter`
- `ContentRetriever` (the old `Retriever` is now deprecated)
- `ContentAggregator`
- `ContentInjector`

This modular design aims to separate concerns and simplify development,
testing, and evaluation. Most (if not all) currently known and proven
RAG techniques can be represented as one or more of the base components
listed above.
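
As a rough illustration of how these five components might fit together, here is a minimal, self-contained sketch. The interfaces below are hypothetical, string-based stand-ins that only borrow the component names from this PR; the actual library types and method signatures differ, and the injected prompt text is made up for the example.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

final class RagFlowSketch {

    // Hypothetical, simplified stand-ins (NOT the library's interfaces).
    interface QueryTransformer { List<String> transform(String query); }
    interface ContentRetriever { List<String> retrieve(String query); }
    interface QueryRouter { List<ContentRetriever> route(String query); }
    interface ContentAggregator { List<String> aggregate(List<List<String>> results); }
    interface ContentInjector { String inject(List<String> contents, String userMessage); }

    // Runs the decomposed flow: transform -> route -> retrieve -> aggregate -> inject.
    static String augment(String userMessage,
                          QueryTransformer transformer,
                          QueryRouter router,
                          ContentAggregator aggregator,
                          ContentInjector injector) {
        List<List<String>> results = new ArrayList<>();
        for (String query : transformer.transform(userMessage)) {
            for (ContentRetriever retriever : router.route(query)) {
                results.add(retriever.retrieve(query));
            }
        }
        return injector.inject(aggregator.aggregate(results), userMessage);
    }

    public static void main(String[] args) {
        // A toy retriever that always returns the same content.
        ContentRetriever docs = query -> Collections.singletonList("LangChain4j is a Java library.");

        String augmented = augment(
                "What is LangChain4j?",
                query -> Collections.singletonList(query),   // pass the query through unchanged
                query -> Collections.singletonList(docs),    // always route to the single retriever
                results -> {                                 // flatten and de-duplicate
                    Set<String> unique = new LinkedHashSet<>();
                    results.forEach(unique::addAll);
                    return new ArrayList<>(unique);
                },
                (contents, message) ->
                        message + "\n\nRelevant information:\n" + String.join("\n", contents));

        System.out.println(augmented);
    }
}
```

Because every step sits behind its own small interface, any single step can be swapped (for example, a transformer that expands one query into several) without touching the others.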

Here is how the decomposed RAG flow can be visualized:

![advanced-rag](https://github.com/langchain4j/langchain4j/assets/132277850/b699077d-dabf-4768-a241-3fcd9ab0286c)

This mental and software model aims to simplify thinking about,
reasoning about, and implementing advanced RAG flows.

Each base component listed above has a simple, sensible default
implementation configured in `DefaultRetrievalAugmentor`, which can be
overridden by a more sophisticated implementation (provided by the
library out-of-the-box) or by a custom one. The list of
implementations is expected to grow over time as we discover new
techniques and implement existing proven ones.

This PR also introduces out-of-the-box support for the following proven
RAG techniques:
- Query expansion
- Query compression
- Query routing using LLM
- [Reciprocal Rank
Fusion](https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking)
- Re-ranking ([Cohere Rerank](https://docs.cohere.com/docs/reranking)
integration is coming in a [separate
PR](langchain4j#539)).
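
Of the techniques listed above, Reciprocal Rank Fusion is compact enough to sketch in a few lines. The following is a minimal, stand-alone illustration of the general formula (each item scores the sum of `1 / (k + rank)` over every ranked list it appears in), not the code added by this PR. The class and method names are hypothetical, and `k = 60` is a commonly used constant chosen here as an assumption.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

final class ReciprocalRankFusionSketch {

    // Fuses several ranked lists into one.
    // score(item) = sum over all lists containing it of 1 / (k + rank), with rank starting at 1.
    static List<String> fuse(List<List<String>> rankings, int k) {
        Map<String, Double> scores = new LinkedHashMap<>();
        for (List<String> ranking : rankings) {
            for (int i = 0; i < ranking.size(); i++) {
                scores.merge(ranking.get(i), 1.0 / (k + i + 1), Double::sum);
            }
        }
        List<String> fused = new ArrayList<>(scores.keySet());
        // Highest fused score first.
        fused.sort((a, b) -> Double.compare(scores.get(b), scores.get(a)));
        return fused;
    }

    public static void main(String[] args) {
        List<String> fused = fuse(Arrays.asList(
                Arrays.asList("a", "b", "c"),
                Arrays.asList("b", "c", "d")), 60);
        System.out.println(fused); // prints [b, c, a, d]
    }
}
```

Items that appear near the top of several lists ("b" and "c" here) outrank items that appear in only one list, even when that list ranks them first.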
langchain4j authored Jan 26, 2024
1 parent 5980f67 commit 14fb985
Showing 59 changed files with 4,813 additions and 151 deletions.
Binary file added docs/static/img/advanced-rag.png
49 changes: 48 additions & 1 deletion langchain4j-core/pom.xml
@@ -20,7 +20,7 @@
<packaging>jar</packaging>

<name>langchain4j-core</name>
<description>Core classes and interfaces of langchain4j</description>
<description>Core classes and interfaces of LangChain4j</description>

<dependencies>

@@ -34,6 +34,12 @@
<artifactId>slf4j-api</artifactId>
</dependency>

<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<scope>provided</scope>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
@@ -64,6 +70,17 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.tinylog</groupId>
<artifactId>tinylog-impl</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.tinylog</groupId>
<artifactId>slf4j-tinylog</artifactId>
<scope>test</scope>
</dependency>

</dependencies>

<build>
@@ -108,6 +125,16 @@
<configuration>
<rules>
<rule>
<excludes>
<exclude>dev.langchain4j.rag</exclude>
<exclude>dev.langchain4j.rag.content</exclude>
<exclude>dev.langchain4j.rag.content.aggregator</exclude>
<exclude>dev.langchain4j.rag.content.injector</exclude>
<exclude>dev.langchain4j.rag.content.retriever</exclude>
<exclude>dev.langchain4j.rag.query</exclude>
<exclude>dev.langchain4j.rag.query.router</exclude>
<exclude>dev.langchain4j.rag.query.transformer</exclude>
</excludes>
<element>PACKAGE</element>
<limits>
<limit>
@@ -117,6 +144,26 @@
</limit>
</limits>
</rule>
<rule>
<includes>
<include>dev.langchain4j.rag</include>
<include>dev.langchain4j.rag.content</include>
<include>dev.langchain4j.rag.content.aggregator</include>
<include>dev.langchain4j.rag.content.injector</include>
<include>dev.langchain4j.rag.content.retriever</include>
<include>dev.langchain4j.rag.query</include>
<include>dev.langchain4j.rag.query.router</include>
<include>dev.langchain4j.rag.query.transformer</include>
</includes>
<element>PACKAGE</element>
<limits>
<limit>
<counter>INSTRUCTION</counter>
<value>COVEREDRATIO</value>
<minimum>0.80</minimum>
</limit>
</limits>
</rule>
</rules>
</configuration>
</execution>
18 changes: 18 additions & 0 deletions langchain4j-core/src/main/java/dev/langchain4j/internal/Utils.java
@@ -2,6 +2,7 @@

import static java.net.HttpURLConnection.HTTP_OK;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Collections.unmodifiableList;

import java.io.ByteArrayOutputStream;
import java.io.InputStream;
@@ -10,6 +11,7 @@
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Collection;
import java.util.List;
import java.util.UUID;
import java.util.function.Supplier;

@@ -207,4 +209,20 @@ public static byte[] readBytes(String url) {
throw new RuntimeException(e);
}
}

/**
* Returns an (unmodifiable) copy of the provided list.
* Returns <code>null</code> if the provided list is <code>null</code>.
*
* @param list The list to copy.
* @param <T> Generic type of the list.
* @return The copy of the provided list.
*/
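public static <T> List<T> copyIfNotNull(List<T> list) {
if (list == null) {
return null;
}

// Copy first: unmodifiableList alone returns a view that still reflects changes to the source list.
return unmodifiableList(new java.util.ArrayList<>(list));
}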
}
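
The difference between an unmodifiable view and an unmodifiable copy matters for a helper documented as returning a copy. The stand-alone sketch below (class and method names are hypothetical, not part of this commit) shows that `Collections.unmodifiableList` on its own wraps the source list without copying it, while copying into a new `ArrayList` first detaches the result from later changes.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

final class UnmodifiableCopySketch {

    // Wraps the list without copying: later changes to 'source' stay visible.
    static <T> List<T> view(List<T> source) {
        return Collections.unmodifiableList(source);
    }

    // Copies first, then wraps: the result is detached from 'source'.
    static <T> List<T> copy(List<T> source) {
        return Collections.unmodifiableList(new ArrayList<>(source));
    }

    public static void main(String[] args) {
        List<String> source = new ArrayList<>(Arrays.asList("a", "b"));
        List<String> view = view(source);
        List<String> copy = copy(source);

        source.add("c");

        System.out.println(view.size()); // prints 3
        System.out.println(copy.size()); // prints 2
    }
}
```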
@@ -1,6 +1,7 @@
package dev.langchain4j.internal;

import java.util.Collection;
import java.util.Map;
import java.util.Objects;

import static dev.langchain4j.internal.Exceptions.illegalArgument;
@@ -67,6 +68,24 @@ public static <T extends Collection<?>> T ensureNotEmpty(T collection, String na
return collection;
}

/**
* Ensures that the given map is not null and not empty.
*
* @param map The map to check.
* @param name The name of the map to be used in the exception message.
* @param <K> The type of the key.
* @param <V> The type of the value.
* @return The map if it is not null and not empty.
* @throws IllegalArgumentException if the map is null or empty.
*/
public static <K, V> Map<K, V> ensureNotEmpty(Map<K, V> map, String name) {
if (map == null || map.isEmpty()) {
throw illegalArgument("%s cannot be null or empty", name);
}

return map;
}

/**
* Ensures that the given string is not null and not blank.
* @param string The string to check.
@@ -10,6 +10,7 @@
import java.util.HashMap;
import java.util.Map;

import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static java.util.Collections.singletonMap;

@@ -29,6 +30,7 @@ public class PromptTemplate {
static final String CURRENT_TIME = "current_time";
static final String CURRENT_DATE_TIME = "current_date_time";

private final String templateString;
private final PromptTemplateFactory.Template template;
private final Clock clock;

@@ -49,10 +51,18 @@ public PromptTemplate(String template) {
* @param clock the clock to use for the special variables.
*/
PromptTemplate(String template, Clock clock) {
this.templateString = ensureNotBlank(template, "template");
this.template = FACTORY.create(() -> template);
this.clock = ensureNotNull(clock, "clock");
}

/**
* @return A prompt template string.
*/
public String template() {
return templateString;
}

/**
* Applies a value to a template containing a single variable. The single variable should have the name {{it}}.
*
@@ -0,0 +1,53 @@
package dev.langchain4j.model.scoring;

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.output.Response;

import java.util.List;

import static dev.langchain4j.internal.ValidationUtils.ensureEq;
import static java.util.Collections.singletonList;

/**
* Represents a model capable of scoring a text against a query.
* <br>
* Useful for identifying the most relevant texts when scoring multiple texts against the same query.
* <br>
* The scoring model can be employed for re-ranking purposes.
*/
public interface ScoringModel {

/**
* Scores a given text against a given query.
*
* @param text The text to be scored.
* @param query The query against which to score the text.
* @return the score.
*/
default Response<Double> score(String text, String query) {
return score(TextSegment.from(text), query);
}

/**
* Scores a given {@link TextSegment} against a given query.
*
* @param segment The {@link TextSegment} to be scored.
* @param query The query against which to score the segment.
* @return the score.
*/
default Response<Double> score(TextSegment segment, String query) {
Response<List<Double>> response = scoreAll(singletonList(segment), query);
ensureEq(response.content().size(), 1,
"Expected a single score, but received %d", response.content().size());
return Response.from(response.content().get(0), response.tokenUsage(), response.finishReason());
}

/**
* Scores all provided {@link TextSegment}s against a given query.
*
* @param segments The list of {@link TextSegment}s to score.
* @param query The query against which to score the segments.
* @return the list of scores. The order of scores corresponds to the order of {@link TextSegment}s.
*/
Response<List<Double>> scoreAll(List<TextSegment> segments, String query);
}
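
To illustrate the re-ranking use mentioned in the Javadoc above, here is a minimal, self-contained sketch. The `Scorer` interface is a hypothetical, string-based stand-in for `scoreAll` (it omits `Response` and `TextSegment`), and the word-overlap scorer is a toy substitute for a real scoring model; neither is part of this commit.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

final class ReRankingSketch {

    // Hypothetical stand-in: returns one score per text, in the same order as the input.
    interface Scorer {
        List<Double> scoreAll(List<String> texts, String query);
    }

    // Toy scorer: counts how many query words occur in each text.
    static final Scorer WORD_OVERLAP = (texts, query) -> {
        List<Double> scores = new ArrayList<>();
        for (String text : texts) {
            List<String> words = Arrays.asList(text.toLowerCase().split("\\s+"));
            double hits = 0;
            for (String term : query.toLowerCase().split("\\s+")) {
                if (words.contains(term)) {
                    hits++;
                }
            }
            scores.add(hits);
        }
        return scores;
    };

    // Scores all texts against the query and keeps the maxResults best, highest score first.
    static List<String> reRank(List<String> texts, String query, Scorer scorer, int maxResults) {
        List<Double> scores = scorer.scoreAll(texts, query);
        if (scores.size() != texts.size()) {
            throw new IllegalStateException("Expected one score per text");
        }
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < texts.size(); i++) {
            order.add(i);
        }
        order.sort((x, y) -> Double.compare(scores.get(y), scores.get(x)));

        List<String> result = new ArrayList<>();
        for (int index : order.subList(0, Math.min(maxResults, order.size()))) {
            result.add(texts.get(index));
        }
        return result;
    }

    public static void main(String[] args) {
        List<String> texts = Arrays.asList("python tutorial", "rag in java", "java basics");
        System.out.println(reRank(texts, "java rag", WORD_OVERLAP, 2));
        // prints [rag in java, java basics]
    }
}
```

The sketch relies on the same contract the interface documents: the order of the returned scores corresponds to the order of the scored segments.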