18
18
19
19
import java .util .List ;
20
20
import java .util .Map ;
21
+ import java .util .function .Predicate ;
21
22
22
23
import reactor .core .publisher .Flux ;
23
24
37
38
* An advisor that blocks the call to the model provider if the user input contains any of
38
39
* the sensitive words.
39
40
*
41
+ * @author Karson To
40
42
* @author Christian Tzolov
41
43
* @author Ilayaperumal Gopinathan
42
44
* @author Thomas Vitale
@@ -49,19 +51,19 @@ public class SafeGuardAdvisor implements CallAdvisor, StreamAdvisor {
49
51
private static final int DEFAULT_ORDER = 0 ;
50
52
51
53
private final String failureResponse ;
52
-
53
- private final List <String > sensitiveWords ;
54
+
55
+ private final Predicate <String > contentValidator ;
54
56
55
57
private final int order ;
56
58
57
- public SafeGuardAdvisor (List <String > sensitiveWords ) {
58
- this (sensitiveWords , DEFAULT_FAILURE_RESPONSE , DEFAULT_ORDER );
59
+ public SafeGuardAdvisor (Predicate <String > contentValidator ) {
60
+ this (contentValidator , DEFAULT_FAILURE_RESPONSE , DEFAULT_ORDER );
59
61
}
60
62
61
- public SafeGuardAdvisor (List <String > sensitiveWords , String failureResponse , int order ) {
62
- Assert .notNull (sensitiveWords , "Sensitive words must not be null!" );
63
+ public SafeGuardAdvisor (Predicate <String > contentValidator , String failureResponse , int order ) {
64
+ Assert .notNull (contentValidator , "SContent validator must not be null!" );
63
65
Assert .notNull (failureResponse , "Failure response must not be null!" );
64
- this .sensitiveWords = sensitiveWords ;
66
+ this .contentValidator = contentValidator ;
65
67
this .failureResponse = failureResponse ;
66
68
this .order = order ;
67
69
}
@@ -76,8 +78,7 @@ public String getName() {
76
78
77
79
@ Override
78
80
public ChatClientResponse adviseCall (ChatClientRequest chatClientRequest , CallAdvisorChain callAdvisorChain ) {
79
- if (!CollectionUtils .isEmpty (this .sensitiveWords )
80
- && this .sensitiveWords .stream ().anyMatch (w -> chatClientRequest .prompt ().getContents ().contains (w ))) {
81
+ if (contentValidator .test (chatClientRequest .prompt ().getContents ())) {
81
82
return createFailureResponse (chatClientRequest );
82
83
}
83
84
@@ -87,8 +88,7 @@ public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAd
87
88
@ Override
88
89
public Flux <ChatClientResponse > adviseStream (ChatClientRequest chatClientRequest ,
89
90
StreamAdvisorChain streamAdvisorChain ) {
90
- if (!CollectionUtils .isEmpty (this .sensitiveWords )
91
- && this .sensitiveWords .stream ().anyMatch (w -> chatClientRequest .prompt ().getContents ().contains (w ))) {
91
+ if (contentValidator .test (chatClientRequest .prompt ().getContents ())) {
92
92
return Flux .just (createFailureResponse (chatClientRequest ));
93
93
}
94
94
@@ -111,7 +111,7 @@ public int getOrder() {
111
111
112
112
public static final class Builder {
113
113
114
- private List <String > sensitiveWords ;
114
+ private Predicate <String > contentValidator ;
115
115
116
116
private String failureResponse = DEFAULT_FAILURE_RESPONSE ;
117
117
@@ -120,8 +120,8 @@ public static final class Builder {
120
120
private Builder () {
121
121
}
122
122
123
- public Builder sensitiveWords (List <String > sensitiveWords ) {
124
- this .sensitiveWords = sensitiveWords ;
123
+ public Builder sensitiveWords (Predicate <String > contentValidator ) {
124
+ this .contentValidator = contentValidator ;
125
125
return this ;
126
126
}
127
127
@@ -136,7 +136,7 @@ public Builder order(int order) {
136
136
}
137
137
138
138
public SafeGuardAdvisor build () {
139
- return new SafeGuardAdvisor (this .sensitiveWords , this .failureResponse , this .order );
139
+ return new SafeGuardAdvisor (this .contentValidator , this .failureResponse , this .order );
140
140
}
141
141
142
142
}
0 commit comments