Skip to content

Commit 0a23c39

Browse files
committed
enhance SafeGuardAdvisor
Signed-off-by: Karson To <karsontao@hotmail.com>
1 parent f78b549 commit 0a23c39

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.util.List;
2020
import java.util.Map;
21+
import java.util.function.Predicate;
2122

2223
import reactor.core.publisher.Flux;
2324

@@ -37,6 +38,7 @@
3738
* An advisor that blocks the call to the model provider if the user input contains any of
3839
* the sensitive words.
3940
*
41+
* @author Karson To
4042
* @author Christian Tzolov
4143
* @author Ilayaperumal Gopinathan
4244
* @author Thomas Vitale
@@ -49,19 +51,19 @@ public class SafeGuardAdvisor implements CallAdvisor, StreamAdvisor {
4951
private static final int DEFAULT_ORDER = 0;
5052

5153
private final String failureResponse;
52-
53-
private final List<String> sensitiveWords;
54+
55+
private final Predicate<String> contentValidator;
5456

5557
private final int order;
5658

57-
public SafeGuardAdvisor(List<String> sensitiveWords) {
58-
this(sensitiveWords, DEFAULT_FAILURE_RESPONSE, DEFAULT_ORDER);
59+
public SafeGuardAdvisor(Predicate<String> contentValidator) {
60+
this(contentValidator, DEFAULT_FAILURE_RESPONSE, DEFAULT_ORDER);
5961
}
6062

61-
public SafeGuardAdvisor(List<String> sensitiveWords, String failureResponse, int order) {
62-
Assert.notNull(sensitiveWords, "Sensitive words must not be null!");
63+
public SafeGuardAdvisor(Predicate<String> contentValidator, String failureResponse, int order) {
64+
Assert.notNull(contentValidator, "SContent validator must not be null!");
6365
Assert.notNull(failureResponse, "Failure response must not be null!");
64-
this.sensitiveWords = sensitiveWords;
66+
this.contentValidator = contentValidator;
6567
this.failureResponse = failureResponse;
6668
this.order = order;
6769
}
@@ -76,8 +78,7 @@ public String getName() {
7678

7779
@Override
7880
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
79-
if (!CollectionUtils.isEmpty(this.sensitiveWords)
80-
&& this.sensitiveWords.stream().anyMatch(w -> chatClientRequest.prompt().getContents().contains(w))) {
81+
if (contentValidator.test(chatClientRequest.prompt().getContents())) {
8182
return createFailureResponse(chatClientRequest);
8283
}
8384

@@ -87,8 +88,7 @@ public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAd
8788
@Override
8889
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
8990
StreamAdvisorChain streamAdvisorChain) {
90-
if (!CollectionUtils.isEmpty(this.sensitiveWords)
91-
&& this.sensitiveWords.stream().anyMatch(w -> chatClientRequest.prompt().getContents().contains(w))) {
91+
if (contentValidator.test(chatClientRequest.prompt().getContents())) {
9292
return Flux.just(createFailureResponse(chatClientRequest));
9393
}
9494

@@ -111,7 +111,7 @@ public int getOrder() {
111111

112112
public static final class Builder {
113113

114-
private List<String> sensitiveWords;
114+
private Predicate<String> contentValidator;
115115

116116
private String failureResponse = DEFAULT_FAILURE_RESPONSE;
117117

@@ -120,8 +120,8 @@ public static final class Builder {
120120
private Builder() {
121121
}
122122

123-
public Builder sensitiveWords(List<String> sensitiveWords) {
124-
this.sensitiveWords = sensitiveWords;
123+
public Builder sensitiveWords(Predicate<String> contentValidator) {
124+
this.contentValidator = contentValidator;
125125
return this;
126126
}
127127

@@ -136,7 +136,7 @@ public Builder order(int order) {
136136
}
137137

138138
public SafeGuardAdvisor build() {
139-
return new SafeGuardAdvisor(this.sensitiveWords, this.failureResponse, this.order);
139+
return new SafeGuardAdvisor(this.contentValidator, this.failureResponse, this.order);
140140
}
141141

142142
}

0 commit comments

Comments
 (0)