Skip to content

Add builder pattern and order parameter to advisors #1507

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
import org.springframework.util.Assert;
import org.stringtemplate.v4.compiler.CodeGenerator.includeExpr_return;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
Expand Down Expand Up @@ -56,12 +57,20 @@ public abstract class AbstractChatMemoryAdvisor<T> implements CallAroundAdvisor,

private final boolean protectFromBlocking;

private final int order;

protected AbstractChatMemoryAdvisor(T chatMemory) {
this(chatMemory, DEFAULT_CHAT_MEMORY_CONVERSATION_ID, DEFAULT_CHAT_MEMORY_RESPONSE_SIZE, true);
}

protected AbstractChatMemoryAdvisor(T chatMemory, String defaultConversationId, int defaultChatMemoryRetrieveSize,
boolean protectFromBlocking) {
this(chatMemory, defaultConversationId, defaultChatMemoryRetrieveSize, protectFromBlocking,
Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER);
}

protected AbstractChatMemoryAdvisor(T chatMemory, String defaultConversationId, int defaultChatMemoryRetrieveSize,
boolean protectFromBlocking, int order) {

Assert.notNull(chatMemory, "The chatMemory must not be null!");
Assert.hasText(defaultConversationId, "The conversationId must not be empty!");
Expand All @@ -71,6 +80,7 @@ protected AbstractChatMemoryAdvisor(T chatMemory, String defaultConversationId,
this.defaultConversationId = defaultConversationId;
this.defaultChatMemoryRetrieveSize = defaultChatMemoryRetrieveSize;
this.protectFromBlocking = protectFromBlocking;
this.order = order;
}

@Override
Expand All @@ -80,11 +90,11 @@ public String getName() {

@Override
public int getOrder() {
// The (Ordered.HIGHEST_PRECEDENCE + 1000) value ensures this order has lower
// priority (e.g. precedences) than the internal Spring AI advisors. It leaves
// room (1000 slots) for the user to plug in their own advisors with higher
// by default the (Ordered.HIGHEST_PRECEDENCE + 1000) value ensures this order has
// lower priority (e.g. precedences) than the internal Spring AI advisors. It
// leaves room (1000 slots) for the user to plug in their own advisors with higher
// priority.
return Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER;
return this.order;
}

protected T getChatMemoryStore() {
Expand Down Expand Up @@ -118,5 +128,43 @@ protected Flux<AdvisedResponse> doNextWithProtectFromBlockingBefore(AdvisedReque
: chain.nextAroundStream(beforeAdvise.apply(advisedRequest));
}

public static abstract class AbstractBuilder<T> {

protected String conversationId = DEFAULT_CHAT_MEMORY_CONVERSATION_ID;

protected int chatMemoryRetrieveSize = DEFAULT_CHAT_MEMORY_RESPONSE_SIZE;

protected boolean protectFromBlocking = true;

protected int order = Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER;

protected T chatMemory;

protected AbstractBuilder(T chatMemory) {
this.chatMemory = chatMemory;
}

public AbstractBuilder withConversationId(String conversationId) {
this.conversationId = conversationId;
return this;
}

public AbstractBuilder withChatMemoryRetrieveSize(int chatMemoryRetrieveSize) {
this.chatMemoryRetrieveSize = chatMemoryRetrieveSize;
return this;
}

public AbstractBuilder withProtectFromBlocking(boolean protectFromBlocking) {
this.protectFromBlocking = protectFromBlocking;
return this;
}

public AbstractBuilder withOrder(int order) {
this.order = order;
return this;
}

abstract public <T> AbstractChatMemoryAdvisor<T> build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
import org.springframework.ai.chat.memory.ChatMemory;
Expand All @@ -43,7 +44,12 @@ public MessageChatMemoryAdvisor(ChatMemory chatMemory) {
}

public MessageChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int chatHistoryWindowSize) {
super(chatMemory, defaultConversationId, chatHistoryWindowSize, true);
this(chatMemory, defaultConversationId, chatHistoryWindowSize, Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER);
}

public MessageChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int chatHistoryWindowSize,
int order) {
super(chatMemory, defaultConversationId, chatHistoryWindowSize, true, order);
}

@Override
Expand Down Expand Up @@ -101,4 +107,21 @@ private void observeAfter(AdvisedResponse advisedResponse) {
this.getChatMemoryStore().add(this.doGetConversationId(advisedResponse.adviseContext()), assistantMessages);
}

public static Builder builder(ChatMemory chatMemory) {
return new Builder(chatMemory);
}

public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder<ChatMemory> {

protected Builder(ChatMemory chatMemory) {
super(chatMemory);
}

public MessageChatMemoryAdvisor build() {
return new MessageChatMemoryAdvisor(this.chatMemory, this.conversationId, this.chatMemoryRetrieveSize,
this.order);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
import org.springframework.ai.chat.memory.ChatMemory;
Expand Down Expand Up @@ -66,7 +67,13 @@ public PromptChatMemoryAdvisor(ChatMemory chatMemory, String systemTextAdvise) {

public PromptChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int chatHistoryWindowSize,
String systemTextAdvise) {
super(chatMemory, defaultConversationId, chatHistoryWindowSize, true);
this(chatMemory, defaultConversationId, chatHistoryWindowSize, systemTextAdvise,
Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER);
}

public PromptChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int chatHistoryWindowSize,
String systemTextAdvise, int order) {
super(chatMemory, defaultConversationId, chatHistoryWindowSize, true, order);
this.systemTextAdvise = systemTextAdvise;
}

Expand Down Expand Up @@ -133,4 +140,28 @@ private void observeAfter(AdvisedResponse advisedResponse) {
this.getChatMemoryStore().add(this.doGetConversationId(advisedResponse.adviseContext()), assistantMessages);
}

public static Builder builder(ChatMemory chatMemory) {
return new Builder(chatMemory);
}

public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder<ChatMemory> {

private String systemTextAdvise = DEFAULT_SYSTEM_TEXT_ADVISE;

protected Builder(ChatMemory chatMemory) {
super(chatMemory);
}

public Builder withSystemTextAdvise(String systemTextAdvise) {
this.systemTextAdvise = systemTextAdvise;
return this;
}

public PromptChatMemoryAdvisor build() {
return new PromptChatMemoryAdvisor(this.chatMemory, this.conversationId, this.chatMemoryRetrieveSize,
this.systemTextAdvise, this.order);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@

import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
import org.springframework.ai.chat.model.ChatResponse;
Expand Down Expand Up @@ -61,6 +61,8 @@ public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdv
the user that you can't answer the question.
""";

private static final int DEFAULT_ORDER = 0;

private final VectorStore vectorStore;

private final String userTextAdvise;
Expand All @@ -73,6 +75,8 @@ public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdv

private final boolean protectFromBlocking;

private final int order;

/**
* The QuestionAnswerAdvisor retrieves context information from a Vector Store and
* combines it with the user's text.
Expand Down Expand Up @@ -121,6 +125,25 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques
*/
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise,
boolean protectFromBlocking) {
this(vectorStore, searchRequest, userTextAdvise, protectFromBlocking, DEFAULT_ORDER);
}

/**
* The QuestionAnswerAdvisor retrieves context information from a Vector Store and
* combines it with the user's text.
* @param vectorStore The vector store to use
* @param searchRequest The search request defined using the portable filter
* expression syntax
* @param userTextAdvise the user text to append to the existing user prompt. The text
* should contain a placeholder named "question_answer_context".
* @param protectFromBlocking if true the advisor will protect the execution from
* blocking threads. If false the advisor will not protect the execution from blocking
* threads. This is useful when the advisor is used in a non-blocking environment. It
* is true by default.
* @param order the order of the advisor.
*/
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise,
boolean protectFromBlocking, int order) {

Assert.notNull(vectorStore, "The vectorStore must not be null!");
Assert.notNull(searchRequest, "The searchRequest must not be null!");
Expand All @@ -130,6 +153,7 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques
this.searchRequest = searchRequest;
this.userTextAdvise = userTextAdvise;
this.protectFromBlocking = protectFromBlocking;
this.order = order;
}

@Override
Expand All @@ -139,7 +163,7 @@ public String getName() {

@Override
public int getOrder() {
return 0;
return this.order;
}

@Override
Expand Down Expand Up @@ -249,6 +273,8 @@ public static class Builder {

private boolean protectFromBlocking = true;

private int order = DEFAULT_ORDER;

private Builder(VectorStore vectorStore) {
Assert.notNull(vectorStore, "The vectorStore must not be null!");
this.vectorStore = vectorStore;
Expand All @@ -271,9 +297,14 @@ public Builder withProtectFromBlocking(boolean protectFromBlocking) {
return this;
}

public Builder withOrder(int order) {
this.order = order;
return this;
}

public QuestionAnswerAdvisor build() {
return new QuestionAnswerAdvisor(this.vectorStore, this.searchRequest, this.userTextAdvise,
this.protectFromBlocking);
this.protectFromBlocking, this.order);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
import org.springframework.ai.chat.messages.AssistantMessage;
Expand Down Expand Up @@ -78,7 +79,13 @@ public VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConve

public VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId,
int chatHistoryWindowSize, String systemTextAdvise) {
super(vectorStore, defaultConversationId, chatHistoryWindowSize, true);
this(vectorStore, defaultConversationId, chatHistoryWindowSize, systemTextAdvise,
Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER);
}

public VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId,
int chatHistoryWindowSize, String systemTextAdvise, int order) {
super(vectorStore, defaultConversationId, chatHistoryWindowSize, true, order);
this.systemTextAdvise = systemTextAdvise;
}

Expand Down Expand Up @@ -168,4 +175,29 @@ else if (message instanceof AssistantMessage assistantMessage) {
return docs;
}

public static Builder builder(VectorStore chatMemory) {
return new Builder(chatMemory);
}

public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder<VectorStore> {

private String systemTextAdvise = DEFAULT_SYSTEM_TEXT_ADVISE;

protected Builder(VectorStore chatMemory) {
super(chatMemory);
}

public Builder withSystemTextAdvise(String systemTextAdvise) {
this.systemTextAdvise = systemTextAdvise;
return this;
}

@Override
public VectorStoreChatMemoryAdvisor build() {
return new VectorStoreChatMemoryAdvisor(this.chatMemory, this.conversationId, this.chatMemoryRetrieveSize,
this.systemTextAdvise);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,8 @@ The following advisor implementations use the `ChatMemory` interface to advice t

* `MessageChatMemoryAdvisor` : Memory is retrieved and added as a collection of messages to the prompt
* `PromptChatMemoryAdvisor` : Memory is retrieved and added into the prompt's system text.
* `VectorStoreChatMemoryAdvisor` : The constructor `VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId, int chatHistoryWindowSize)` lets you specify the VectorStore to retrieve the chat history from, the unique conversation ID, the size of the chat history to be retrieved in token size.
* `VectorStoreChatMemoryAdvisor` : The constructor `VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId, int chatHistoryWindowSize, int order)` lets you specify the VectorStore to retrieve the chat history from, the unique conversation ID, the size of the chat history to be retrieved in token size.
The VectorStoreChatMemoryAdvisor.builder() method lets you specify the default conversation ID, the chat history window size, and the order of the chat history to be retrieved.

A sample `@Service` implementation that uses several advisors is shown below.

Expand Down Expand Up @@ -452,10 +453,9 @@ public class CustomerSupportAssistant {
If there is a charge for the change, you MUST ask the user to consent before proceeding.
""")
.defaultAdvisors(
new PromptChatMemoryAdvisor(chatMemory),
// new MessageChatMemoryAdvisor(chatMemory), // CHAT MEMORY
new MessageChatMemoryAdvisor(chatMemory), // CHAT MEMORY
new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults()), // RAG
new LoggingAdvisor())
new SimpleLoggerAdvisor())
.defaultFunctions("getBookingDetails", "changeBooking", "cancelBooking") // FUNCTION CALLING
.build();
}
Expand Down