Skip to content

Commit 916faac

Browse files
bsboddenclaude
andcommitted
feat(messagehistory): add role filtering to getRecent() method (#349)
Implements role-based filtering for message retrieval, allowing users to filter messages by role type (system, user, llm, tool) when querying conversation history. Changes: - Add validateRoles() method to BaseMessageHistory for role validation - Update getRecent() signature to accept optional role parameter - Support single role string or List<String> of multiple roles - Implement role filtering in MessageHistory using Filter.tag() combinations - Maintain backward compatibility with existing getRecent() calls Tests: - Add RoleFilteringTest with 15 unit tests for validation logic - Add RoleFilteringIntegrationTest with 15 integration tests - Test single role, multiple roles, null role, and error cases - Test role filtering with top_k, session_tag, raw, and asText parameters Python reference: PR #387 - Role filtering for message history Ported from: tests/integration/test_role_filter_get_recent.py 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 4938e12 commit 916faac

File tree

4 files changed

+726
-5
lines changed

4 files changed

+726
-5
lines changed

core/src/main/java/com/redis/vl/extensions/messagehistory/BaseMessageHistory.java

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import java.util.HashMap;
88
import java.util.List;
99
import java.util.Map;
10+
import java.util.Set;
1011

1112
/**
1213
* Base class for message history implementations.
@@ -15,6 +16,9 @@
1516
*/
1617
public abstract class BaseMessageHistory {
1718

19+
/** Valid role values for message filtering. */
20+
private static final Set<String> VALID_ROLES = Set.of("system", "user", "llm", "tool");
21+
1822
protected final String name;
1923
protected final String sessionTag;
2024

@@ -55,10 +59,14 @@ protected BaseMessageHistory(String name, String sessionTag) {
5559
* @param raw Whether to return the full Redis hash entry or just the role/content/tool_call_id.
5660
* @param sessionTag Tag of the entries linked to a specific conversation session. Defaults to
5761
* instance ULID.
62+
* @param role Filter messages by role(s). Can be a single role string ("system", "user", "llm",
63+
* "tool"), a List of role strings, or null for no filtering.
5864
* @return List of messages (either as text strings or maps depending on asText parameter)
59-
* @throws IllegalArgumentException if topK is not an integer greater than or equal to 0
65+
* @throws IllegalArgumentException if topK is not an integer greater than or equal to 0, or if
66+
* role contains invalid values
6067
*/
61-
public abstract <T> List<T> getRecent(int topK, boolean asText, boolean raw, String sessionTag);
68+
public abstract <T> List<T> getRecent(
69+
int topK, boolean asText, boolean raw, String sessionTag, Object role);
6270

6371
/**
6472
* Insert a prompt:response pair into the message history.
@@ -117,6 +125,60 @@ protected <T> List<T> formatContext(List<Map<String, Object>> messages, boolean
117125
return context;
118126
}
119127

128+
/**
129+
* Validate and normalize role parameter for filtering messages.
130+
*
131+
* <p>Matches Python _validate_roles from base_history.py (lines 90-128)
132+
*
133+
* @param role A single role string, List of roles, or null
134+
* @return List of valid role strings if role is provided, null otherwise
135+
* @throws IllegalArgumentException if role contains invalid values or is the wrong type
136+
*/
137+
@SuppressWarnings("unchecked")
138+
protected List<String> validateRoles(Object role) {
139+
if (role == null) {
140+
return null;
141+
}
142+
143+
// Handle single role string
144+
if (role instanceof String) {
145+
String roleStr = (String) role;
146+
if (!VALID_ROLES.contains(roleStr)) {
147+
throw new IllegalArgumentException(
148+
String.format("Invalid role '%s'. Valid roles are: %s", roleStr, VALID_ROLES));
149+
}
150+
return List.of(roleStr);
151+
}
152+
153+
// Handle list of roles
154+
if (role instanceof List) {
155+
List<?> roleList = (List<?>) role;
156+
157+
if (roleList.isEmpty()) {
158+
throw new IllegalArgumentException("roles cannot be empty");
159+
}
160+
161+
// Validate all roles in the list
162+
List<String> validatedRoles = new ArrayList<>();
163+
for (Object r : roleList) {
164+
if (!(r instanceof String)) {
165+
throw new IllegalArgumentException(
166+
"role list must contain only strings, found: " + r.getClass().getSimpleName());
167+
}
168+
String roleStr = (String) r;
169+
if (!VALID_ROLES.contains(roleStr)) {
170+
throw new IllegalArgumentException(
171+
String.format("Invalid role '%s'. Valid roles are: %s", roleStr, VALID_ROLES));
172+
}
173+
validatedRoles.add(roleStr);
174+
}
175+
176+
return validatedRoles;
177+
}
178+
179+
throw new IllegalArgumentException("role must be a String, List<String>, or null");
180+
}
181+
120182
public String getName() {
121183
return name;
122184
}

core/src/main/java/com/redis/vl/extensions/messagehistory/MessageHistory.java

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ public void delete() {
7777
public void drop(String id) {
7878
if (id == null) {
7979
// Get the most recent message
80-
List<Map<String, Object>> recent = getRecent(1, false, true, null);
80+
List<Map<String, Object>> recent = getRecent(1, false, true, null, null);
8181
if (!recent.isEmpty()) {
8282
id = (String) recent.get(0).get(ID_FIELD_NAME);
8383
} else {
@@ -111,14 +111,31 @@ public List<Map<String, Object>> getMessages() {
111111
return formatContext(messages, false);
112112
}
113113

114+
/**
115+
* Retrieve the recent conversation history (backward-compatible overload without role filter).
116+
*
117+
* @param topK The number of previous messages to return
118+
* @param asText Whether to return as text strings or maps
119+
* @param raw Whether to return full Redis hash entries
120+
* @param sessionTag Session tag to filter by
121+
* @return List of messages
122+
*/
123+
public <T> List<T> getRecent(int topK, boolean asText, boolean raw, String sessionTag) {
124+
return getRecent(topK, asText, raw, sessionTag, null);
125+
}
126+
114127
@Override
115128
@SuppressWarnings("unchecked")
116-
public <T> List<T> getRecent(int topK, boolean asText, boolean raw, String sessionTag) {
129+
public <T> List<T> getRecent(
130+
int topK, boolean asText, boolean raw, String sessionTag, Object role) {
117131
// Validate topK
118132
if (topK < 0) {
119133
throw new IllegalArgumentException("topK must be an integer greater than or equal to 0");
120134
}
121135

136+
// Validate and normalize role parameter
137+
List<String> rolesToFilter = validateRoles(role);
138+
122139
List<String> returnFields =
123140
List.of(
124141
ID_FIELD_NAME,
@@ -131,9 +148,26 @@ public <T> List<T> getRecent(int topK, boolean asText, boolean raw, String sessi
131148
Filter sessionFilter =
132149
(sessionTag != null) ? Filter.tag(SESSION_FIELD_NAME, sessionTag) : defaultSessionFilter;
133150

151+
// Combine session filter with role filter if provided
152+
Filter filterExpression = sessionFilter;
153+
if (rolesToFilter != null) {
154+
if (rolesToFilter.size() == 1) {
155+
// Single role filter
156+
Filter roleFilter = Filter.tag(ROLE_FIELD_NAME, rolesToFilter.get(0));
157+
filterExpression = Filter.and(sessionFilter, roleFilter);
158+
} else {
159+
// Multiple roles - use OR logic
160+
Filter roleFilter = Filter.tag(ROLE_FIELD_NAME, rolesToFilter.get(0));
161+
for (int i = 1; i < rolesToFilter.size(); i++) {
162+
roleFilter = Filter.or(roleFilter, Filter.tag(ROLE_FIELD_NAME, rolesToFilter.get(i)));
163+
}
164+
filterExpression = Filter.and(sessionFilter, roleFilter);
165+
}
166+
}
167+
134168
FilterQuery query =
135169
FilterQuery.builder()
136-
.filterExpression(sessionFilter)
170+
.filterExpression(filterExpression)
137171
.returnFields(returnFields)
138172
.numResults(topK)
139173
.sortBy(TIMESTAMP_FIELD_NAME)

0 commit comments

Comments
 (0)