forked from langchain4j/langchain4j
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Foundation for advanced RAG (langchain4j#538)
So far, LangChain4j had only a simple (a.k.a., naive) RAG implementation: a single `Retriever` was invoked on each interaction with the LLM, and all retrieved `TextSegments` were appended to the end of the `UserMessage`. This approach was very limiting. This PR introduces support for much more advanced RAG use cases. The design and mental model are inspired by [this article](https://blog.langchain.dev/deconstructing-rag/) and [this paper](https://arxiv.org/abs/2312.10997), making it advisable to read the article. This PR introduces a `RetrievalAugmentor` interface responsible for augmenting a `UserMessage` with relevant content before sending it to the LLM. The `RetrievalAugmentor` can be used with both `AiServices` and `ConversationalRetrievalChain`, as well as stand-alone. A default implementation of `RetrievalAugmentor` (`DefaultRetrievalAugmentor`) is provided with the library and is suggested as a good starting point. However, users are not limited to it and can have more freedom with their own custom implementations. `DefaultRetrievalAugmentor` decomposes the entire RAG flow into more granular steps and base components: - `QueryTransformer` - `QueryRouter` - `ContentRetriever` (the old `Retriever` is now deprecated) - `ContentAggregator` - `ContentInjector` This modular design aims to separate concerns and simplify development, testing, and evaluation. Most (if not all) currently known and proven RAG techniques can be represented as one or multiple base components listed above. Here is how the decomposed RAG flow can be visualized: ![advanced-rag](https://github.com/langchain4j/langchain4j/assets/132277850/b699077d-dabf-4768-a241-3fcd9ab0286c) This mental and software model aims to simplify the thinking, reasoning, and implementation of advanced RAG flows. Each base component listed above has a sensible and simple default implementation configured in `DefaultRetrievalAugmentor` by default but can be overridden by more sophisticated implementations (provided by the library out-of-the-box) as well as custom ones. The list of implementations is expected to grow over time as we discover new techniques and implement existing proven ones. This PR also introduces out-of-the-box support for the following proven RAG techniques: - Query expansion - Query compression - Query routing using LLM - [Reciprocal Rank Fusion](https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking) - Re-ranking ([Cohere Rerank](https://docs.cohere.com/docs/reranking) integration is coming in a [separate PR](langchain4j#539)).
- Loading branch information
1 parent
5980f67
commit 14fb985
Showing
59 changed files
with
4,813 additions
and
151 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
53 changes: 53 additions & 0 deletions
53
langchain4j-core/src/main/java/dev/langchain4j/model/scoring/ScoringModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
package dev.langchain4j.model.scoring; | ||
|
||
import dev.langchain4j.data.segment.TextSegment; | ||
import dev.langchain4j.model.output.Response; | ||
|
||
import java.util.List; | ||
|
||
import static dev.langchain4j.internal.ValidationUtils.ensureEq; | ||
import static java.util.Collections.singletonList; | ||
|
||
/** | ||
* Represents a model capable of scoring a text against a query. | ||
* <br> | ||
* Useful for identifying the most relevant texts when scoring multiple texts against the same query. | ||
* <br> | ||
* The scoring model can be employed for re-ranking purposes. | ||
*/ | ||
public interface ScoringModel { | ||
|
||
/** | ||
* Scores a given text against a given query. | ||
* | ||
* @param text The text to be scored. | ||
* @param query The query against which to score the text. | ||
* @return the score. | ||
*/ | ||
default Response<Double> score(String text, String query) { | ||
return score(TextSegment.from(text), query); | ||
} | ||
|
||
/** | ||
* Scores a given {@link TextSegment} against a given query. | ||
* | ||
* @param segment The {@link TextSegment} to be scored. | ||
* @param query The query against which to score the segment. | ||
* @return the score. | ||
*/ | ||
default Response<Double> score(TextSegment segment, String query) { | ||
Response<List<Double>> response = scoreAll(singletonList(segment), query); | ||
ensureEq(response.content().size(), 1, | ||
"Expected a single score, but received %d", response.content().size()); | ||
return Response.from(response.content().get(0), response.tokenUsage(), response.finishReason()); | ||
} | ||
|
||
/** | ||
* Scores all provided {@link TextSegment}s against a given query. | ||
* | ||
* @param segments The list of {@link TextSegment}s to score. | ||
* @param query The query against which to score the segments. | ||
* @return the list of scores. The order of scores corresponds to the order of {@link TextSegment}s. | ||
*/ | ||
Response<List<Double>> scoreAll(List<TextSegment> segments, String query); | ||
} |
Oops, something went wrong.