Skip to content

Commit 31c044a

Browse files
committed
Modular RAG: Retrieval with Vector Stores
* Establish new package for Modular RAG components. * Add new Query API, representing a query in the context of a RAG flow. * Define Retrieval package for the RAG building blocks handling the data retrieval operations. * Relocate DocumentRetriever to Retrieval package and implement VectorStoreDocumentRetriever. * Introduce RetrievalAugmentationAdvisor as the successor of QuestionAnswerAdvisor. It uses the Retrieval building blocks described in the previous point. * Make Advisor APIs null-safe and update tests accordingly. Relates to gh-#1603 Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
1 parent 9c7af56 commit 31c044a

File tree

15 files changed

+941
-69
lines changed

15 files changed

+941
-69
lines changed
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.chat.client.advisor;
18+
19+
import java.util.HashMap;
20+
import java.util.List;
21+
import java.util.Map;
22+
import java.util.function.Predicate;
23+
import java.util.stream.Collectors;
24+
25+
import reactor.core.publisher.Flux;
26+
import reactor.core.publisher.Mono;
27+
import reactor.core.scheduler.Schedulers;
28+
29+
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
30+
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
31+
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
32+
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
33+
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
34+
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
35+
import org.springframework.ai.chat.messages.UserMessage;
36+
import org.springframework.ai.chat.model.ChatResponse;
37+
import org.springframework.ai.chat.prompt.PromptTemplate;
38+
import org.springframework.ai.document.Document;
39+
import org.springframework.ai.model.Content;
40+
import org.springframework.ai.rag.Query;
41+
import org.springframework.ai.rag.retrieval.source.DocumentRetriever;
42+
import org.springframework.lang.Nullable;
43+
import org.springframework.util.Assert;
44+
import org.springframework.util.StringUtils;
45+
46+
/**
47+
* This advisor implements common Retrieval Augmented Generation (RAG) flows using the
48+
* building blocks defined in the {@link org.springframework.ai.rag} package and following
49+
* the Modular RAG Architecture.
50+
* <p>
51+
* It's the successor of the {@link QuestionAnswerAdvisor}.
52+
*
53+
* @author Christian Tzolov
54+
* @author Thomas Vitale
55+
* @since 1.0.0
56+
* @see <a href="http://export.arxiv.org/abs/2407.21059">arXiv:2407.21059</a>
57+
* @see <a href="https://export.arxiv.org/abs/2312.10997">arXiv:2312.10997</a>
58+
*/
59+
public class RetrievalAugmentationAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {
60+
61+
public static final String DOCUMENT_CONTEXT = "rag_document_context";
62+
63+
public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""
64+
{query}
65+
66+
Context information is below. Use this information to answer the user query.
67+
68+
---------------------
69+
{context}
70+
---------------------
71+
72+
Given the context and provided history information and not prior knowledge,
73+
reply to the user query. If the answer is not in the context, inform
74+
the user that you can't answer the query.
75+
""");
76+
77+
private final DocumentRetriever documentRetriever;
78+
79+
private final PromptTemplate promptTemplate;
80+
81+
private final boolean protectFromBlocking;
82+
83+
private final int order;
84+
85+
public RetrievalAugmentationAdvisor(DocumentRetriever documentRetriever, @Nullable PromptTemplate promptTemplate,
86+
@Nullable Boolean protectFromBlocking, @Nullable Integer order) {
87+
Assert.notNull(documentRetriever, "documentRetriever cannot be null");
88+
this.documentRetriever = documentRetriever;
89+
this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
90+
this.protectFromBlocking = protectFromBlocking != null ? protectFromBlocking : false;
91+
this.order = order != null ? order : 0;
92+
}
93+
94+
public static Builder builder() {
95+
return new Builder();
96+
}
97+
98+
@Override
99+
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
100+
Assert.notNull(advisedRequest, "advisedRequest cannot be null");
101+
Assert.notNull(chain, "chain cannot be null");
102+
103+
AdvisedRequest processedAdvisedRequest = before(advisedRequest);
104+
AdvisedResponse advisedResponse = chain.nextAroundCall(processedAdvisedRequest);
105+
return after(advisedResponse);
106+
}
107+
108+
@Override
109+
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
110+
Assert.notNull(advisedRequest, "advisedRequest cannot be null");
111+
Assert.notNull(chain, "chain cannot be null");
112+
113+
// This can be executed by both blocking and non-blocking Threads
114+
// E.g. a command line or Tomcat blocking Thread implementation
115+
// or by a WebFlux dispatch in a non-blocking manner.
116+
Flux<AdvisedResponse> advisedResponses = (this.protectFromBlocking) ?
117+
// @formatter:off
118+
Mono.just(advisedRequest)
119+
.publishOn(Schedulers.boundedElastic())
120+
.map(this::before)
121+
.flatMapMany(chain::nextAroundStream)
122+
: chain.nextAroundStream(before(advisedRequest));
123+
// @formatter:on
124+
125+
return advisedResponses.map(ar -> {
126+
if (onFinishReason().test(ar)) {
127+
ar = after(ar);
128+
}
129+
return ar;
130+
});
131+
}
132+
133+
private AdvisedRequest before(AdvisedRequest request) {
134+
Map<String, Object> context = new HashMap<>(request.adviseContext());
135+
136+
// 0. Create a query from the user text and parameters.
137+
Query query = new Query(new PromptTemplate(request.userText(), request.userParams()).render());
138+
139+
// 1. Retrieve similar documents for the original query.
140+
List<Document> documents = this.documentRetriever.retrieve(query);
141+
context.put(DOCUMENT_CONTEXT, documents);
142+
143+
// 2. Combine retrieved documents.
144+
String documentContext = documents.stream()
145+
.map(Content::getContent)
146+
.collect(Collectors.joining(System.lineSeparator()));
147+
148+
// 3. Define augmentation prompt parameters.
149+
Map<String, Object> promptParameters = Map.of("query", query.text(), "context", documentContext);
150+
151+
// 4. Augment user prompt with the context data.
152+
UserMessage augmentedUserMessage = (UserMessage) this.promptTemplate.createMessage(promptParameters);
153+
154+
return AdvisedRequest.from(request)
155+
.withUserText(augmentedUserMessage.getContent())
156+
.withAdviseContext(context)
157+
.build();
158+
}
159+
160+
private AdvisedResponse after(AdvisedResponse advisedResponse) {
161+
ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(advisedResponse.response());
162+
chatResponseBuilder.withMetadata(DOCUMENT_CONTEXT, advisedResponse.adviseContext().get(DOCUMENT_CONTEXT));
163+
return new AdvisedResponse(chatResponseBuilder.build(), advisedResponse.adviseContext());
164+
}
165+
166+
private Predicate<AdvisedResponse> onFinishReason() {
167+
return advisedResponse -> advisedResponse.response()
168+
.getResults()
169+
.stream()
170+
.anyMatch(result -> result != null && result.getMetadata() != null
171+
&& StringUtils.hasText(result.getMetadata().getFinishReason()));
172+
}
173+
174+
@Override
175+
public String getName() {
176+
return this.getClass().getSimpleName();
177+
}
178+
179+
@Override
180+
public int getOrder() {
181+
return this.order;
182+
}
183+
184+
public static final class Builder {
185+
186+
private DocumentRetriever documentRetriever;
187+
188+
private PromptTemplate promptTemplate;
189+
190+
private Boolean protectFromBlocking;
191+
192+
private Integer order;
193+
194+
private Builder() {
195+
}
196+
197+
public Builder documentRetriever(DocumentRetriever documentRetriever) {
198+
this.documentRetriever = documentRetriever;
199+
return this;
200+
}
201+
202+
public Builder promptTemplate(PromptTemplate promptTemplate) {
203+
this.promptTemplate = promptTemplate;
204+
return this;
205+
}
206+
207+
public Builder protectFromBlocking(Boolean protectFromBlocking) {
208+
this.protectFromBlocking = protectFromBlocking;
209+
return this;
210+
}
211+
212+
public Builder order(Integer order) {
213+
this.order = order;
214+
return this;
215+
}
216+
217+
public RetrievalAugmentationAdvisor build() {
218+
return new RetrievalAugmentationAdvisor(this.documentRetriever, this.promptTemplate,
219+
this.protectFromBlocking, this.order);
220+
}
221+
222+
}
223+
224+
}

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java

Lines changed: 62 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,15 @@
3434
import org.springframework.ai.model.Media;
3535
import org.springframework.ai.model.function.FunctionCallback;
3636
import org.springframework.ai.model.function.FunctionCallingOptions;
37+
import org.springframework.lang.Nullable;
38+
import org.springframework.util.Assert;
3739
import org.springframework.util.CollectionUtils;
3840
import org.springframework.util.StringUtils;
3941

4042
/**
4143
* The data of the chat client request that can be modified before the execution of the
4244
* ChatClient's call method
4345
*
44-
* @author Christian Tzolov
45-
* @since 1.0.0
4646
* @param chatModel the chat model used
4747
* @param userText the text provided by the user
4848
* @param systemText the text provided by the system
@@ -57,13 +57,53 @@
5757
* @param advisorParams the map of advisor parameters
5858
* @param adviseContext the map of advise context
5959
* @param toolContext the tool context
60+
* @author Christian Tzolov
61+
* @author Thomas Vitale
62+
* @since 1.0.0
6063
*/
61-
public record AdvisedRequest(ChatModel chatModel, String userText, String systemText, ChatOptions chatOptions,
62-
List<Media> media, List<String> functionNames, List<FunctionCallback> functionCallbacks, List<Message> messages,
63-
Map<String, Object> userParams, Map<String, Object> systemParams, List<Advisor> advisors,
64-
Map<String, Object> advisorParams, Map<String, Object> adviseContext, Map<String, Object> toolContext) {
64+
public record AdvisedRequest(
65+
// @formatter:off
66+
ChatModel chatModel,
67+
String userText,
68+
@Nullable
69+
String systemText,
70+
@Nullable
71+
ChatOptions chatOptions,
72+
List<Media> media,
73+
List<String> functionNames,
74+
List<FunctionCallback> functionCallbacks,
75+
List<Message> messages,
76+
Map<String, Object> userParams,
77+
Map<String, Object> systemParams,
78+
List<Advisor> advisors,
79+
Map<String, Object> advisorParams,
80+
Map<String, Object> adviseContext,
81+
Map<String, Object> toolContext
82+
// @formatter:on
83+
) {
84+
85+
public AdvisedRequest {
86+
Assert.notNull(chatModel, "chatModel cannot be null");
87+
Assert.hasText(userText, "userText cannot be null or empty");
88+
Assert.notNull(media, "media cannot be null");
89+
Assert.notNull(functionNames, "functionNames cannot be null");
90+
Assert.notNull(functionCallbacks, "functionCallbacks cannot be null");
91+
Assert.notNull(messages, "messages cannot be null");
92+
Assert.notNull(userParams, "userParams cannot be null");
93+
Assert.notNull(systemParams, "systemParams cannot be null");
94+
Assert.notNull(advisors, "advisors cannot be null");
95+
Assert.notNull(advisorParams, "advisorParams cannot be null");
96+
Assert.notNull(adviseContext, "adviseContext cannot be null");
97+
Assert.notNull(toolContext, "toolContext cannot be null");
98+
}
99+
100+
public static Builder builder() {
101+
return new Builder();
102+
}
65103

66104
public static Builder from(AdvisedRequest from) {
105+
Assert.notNull(from, "AdvisedRequest cannot be null");
106+
67107
Builder builder = new Builder();
68108
builder.chatModel = from.chatModel;
69109
builder.userText = from.userText;
@@ -79,23 +119,18 @@ public static Builder from(AdvisedRequest from) {
79119
builder.advisorParams = from.advisorParams;
80120
builder.adviseContext = from.adviseContext;
81121
builder.toolContext = from.toolContext;
82-
83122
return builder;
84123
}
85124

86-
public static Builder builder() {
87-
return new Builder();
88-
}
89-
90125
public AdvisedRequest updateContext(Function<Map<String, Object>, Map<String, Object>> contextTransform) {
126+
Assert.notNull(contextTransform, "contextTransform cannot be null");
91127
return from(this)
92128
.withAdviseContext(Collections.unmodifiableMap(contextTransform.apply(new HashMap<>(this.adviseContext))))
93129
.build();
94130
}
95131

96132
public Prompt toPrompt() {
97-
98-
var messages = new ArrayList<Message>(this.messages());
133+
var messages = new ArrayList<>(this.messages());
99134

100135
String processedSystemText = this.systemText();
101136
if (StringUtils.hasText(processedSystemText)) {
@@ -111,7 +146,6 @@ public Prompt toPrompt() {
111146
? this.userText() + System.lineSeparator() + "{spring_ai_soc_format}" : this.userText();
112147

113148
if (StringUtils.hasText(processedUserText)) {
114-
115149
Map<String, Object> userParams = new HashMap<>(this.userParams());
116150
if (StringUtils.hasText(formatParam)) {
117151
userParams.put("spring_ai_soc_format", formatParam);
@@ -137,17 +171,15 @@ public Prompt toPrompt() {
137171
return new Prompt(messages, this.chatOptions());
138172
}
139173

140-
public static class Builder {
141-
142-
public Map<String, Object> toolContext = Map.of();
174+
public static final class Builder {
143175

144176
private ChatModel chatModel;
145177

146-
private String userText = "";
178+
private String userText;
147179

148-
private String systemText = "";
180+
private String systemText;
149181

150-
private ChatOptions chatOptions = null;
182+
private ChatOptions chatOptions;
151183

152184
private List<Media> media = List.of();
153185

@@ -167,6 +199,11 @@ public static class Builder {
167199

168200
private Map<String, Object> adviseContext = Map.of();
169201

202+
public Map<String, Object> toolContext = Map.of();
203+
204+
private Builder() {
205+
}
206+
170207
public Builder withChatModel(ChatModel chatModel) {
171208
this.chatModel = chatModel;
172209
return this;
@@ -202,11 +239,6 @@ public Builder withFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
202239
return this;
203240
}
204241

205-
public Builder withToolContext(Map<String, Object> toolContext) {
206-
this.toolContext = toolContext;
207-
return this;
208-
}
209-
210242
public Builder withMessages(List<Message> messages) {
211243
this.messages = messages;
212244
return this;
@@ -237,6 +269,11 @@ public Builder withAdviseContext(Map<String, Object> adviseContext) {
237269
return this;
238270
}
239271

272+
public Builder withToolContext(Map<String, Object> toolContext) {
273+
this.toolContext = toolContext;
274+
return this;
275+
}
276+
240277
public AdvisedRequest build() {
241278
return new AdvisedRequest(this.chatModel, this.userText, this.systemText, this.chatOptions, this.media,
242279
this.functionNames, this.functionCallbacks, this.messages, this.userParams, this.systemParams,

0 commit comments

Comments
 (0)