|
7 | 7 | import java.util.HashMap; |
8 | 8 | import java.util.List; |
9 | 9 | import java.util.Map; |
| 10 | +import java.util.Set; |
10 | 11 |
|
11 | 12 | /** |
12 | 13 | * Base class for message history implementations. |
|
15 | 16 | */ |
16 | 17 | public abstract class BaseMessageHistory { |
17 | 18 |
|
| 19 | + /** Valid role values for message filtering. */ |
| 20 | + private static final Set<String> VALID_ROLES = Set.of("system", "user", "llm", "tool"); |
| 21 | + |
18 | 22 | protected final String name; |
19 | 23 | protected final String sessionTag; |
20 | 24 |
|
@@ -55,10 +59,14 @@ protected BaseMessageHistory(String name, String sessionTag) { |
55 | 59 | * @param raw Whether to return the full Redis hash entry or just the role/content/tool_call_id. |
56 | 60 | * @param sessionTag Tag of the entries linked to a specific conversation session. Defaults to |
57 | 61 | * instance ULID. |
| 62 | + * @param role Filter messages by role(s). Can be a single role string ("system", "user", "llm", |
| 63 | + * "tool"), a List of role strings, or null for no filtering. |
58 | 64 | * @return List of messages (either as text strings or maps depending on asText parameter) |
59 | | - * @throws IllegalArgumentException if topK is not an integer greater than or equal to 0 |
| 65 | + * @throws IllegalArgumentException if topK is not an integer greater than or equal to 0, or if |
| 66 | + * role contains invalid values |
60 | 67 | */ |
61 | | - public abstract <T> List<T> getRecent(int topK, boolean asText, boolean raw, String sessionTag); |
| 68 | + public abstract <T> List<T> getRecent( |
| 69 | + int topK, boolean asText, boolean raw, String sessionTag, Object role); |
62 | 70 |
|
63 | 71 | /** |
64 | 72 | * Insert a prompt:response pair into the message history. |
@@ -117,6 +125,60 @@ protected <T> List<T> formatContext(List<Map<String, Object>> messages, boolean |
117 | 125 | return context; |
118 | 126 | } |
119 | 127 |
|
| 128 | + /** |
| 129 | + * Validate and normalize role parameter for filtering messages. |
| 130 | + * |
| 131 | + * <p>Matches Python _validate_roles from base_history.py (lines 90-128) |
| 132 | + * |
| 133 | + * @param role A single role string, List of roles, or null |
| 134 | + * @return List of valid role strings if role is provided, null otherwise |
| 135 | + * @throws IllegalArgumentException if role contains invalid values or is the wrong type |
| 136 | + */ |
| 137 | + @SuppressWarnings("unchecked") |
| 138 | + protected List<String> validateRoles(Object role) { |
| 139 | + if (role == null) { |
| 140 | + return null; |
| 141 | + } |
| 142 | + |
| 143 | + // Handle single role string |
| 144 | + if (role instanceof String) { |
| 145 | + String roleStr = (String) role; |
| 146 | + if (!VALID_ROLES.contains(roleStr)) { |
| 147 | + throw new IllegalArgumentException( |
| 148 | + String.format("Invalid role '%s'. Valid roles are: %s", roleStr, VALID_ROLES)); |
| 149 | + } |
| 150 | + return List.of(roleStr); |
| 151 | + } |
| 152 | + |
| 153 | + // Handle list of roles |
| 154 | + if (role instanceof List) { |
| 155 | + List<?> roleList = (List<?>) role; |
| 156 | + |
| 157 | + if (roleList.isEmpty()) { |
| 158 | + throw new IllegalArgumentException("roles cannot be empty"); |
| 159 | + } |
| 160 | + |
| 161 | + // Validate all roles in the list |
| 162 | + List<String> validatedRoles = new ArrayList<>(); |
| 163 | + for (Object r : roleList) { |
| 164 | + if (!(r instanceof String)) { |
| 165 | + throw new IllegalArgumentException( |
| 166 | + "role list must contain only strings, found: " + r.getClass().getSimpleName()); |
| 167 | + } |
| 168 | + String roleStr = (String) r; |
| 169 | + if (!VALID_ROLES.contains(roleStr)) { |
| 170 | + throw new IllegalArgumentException( |
| 171 | + String.format("Invalid role '%s'. Valid roles are: %s", roleStr, VALID_ROLES)); |
| 172 | + } |
| 173 | + validatedRoles.add(roleStr); |
| 174 | + } |
| 175 | + |
| 176 | + return validatedRoles; |
| 177 | + } |
| 178 | + |
| 179 | + throw new IllegalArgumentException("role must be a String, List<String>, or null"); |
| 180 | + } |
| 181 | + |
120 | 182 | public String getName() { |
121 | 183 | return name; |
122 | 184 | } |
|
0 commit comments