Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
* Base class for message history implementations.
Expand All @@ -15,6 +16,9 @@
*/
public abstract class BaseMessageHistory {

/** Valid role values for message filtering. */
private static final Set<String> VALID_ROLES = Set.of("system", "user", "llm", "tool");

protected final String name;
protected final String sessionTag;

Expand Down Expand Up @@ -55,10 +59,14 @@ protected BaseMessageHistory(String name, String sessionTag) {
* @param raw Whether to return the full Redis hash entry or just the role/content/tool_call_id.
* @param sessionTag Tag of the entries linked to a specific conversation session. Defaults to
* instance ULID.
* @param role Filter messages by role(s). Can be a single role string ("system", "user", "llm",
* "tool"), a List of role strings, or null for no filtering.
* @return List of messages (either as text strings or maps depending on asText parameter)
* @throws IllegalArgumentException if topK is not an integer greater than or equal to 0
* @throws IllegalArgumentException if topK is not an integer greater than or equal to 0, or if
* role contains invalid values
*/
public abstract <T> List<T> getRecent(int topK, boolean asText, boolean raw, String sessionTag);
public abstract <T> List<T> getRecent(
int topK, boolean asText, boolean raw, String sessionTag, Object role);

/**
* Insert a prompt:response pair into the message history.
Expand Down Expand Up @@ -117,6 +125,60 @@ protected <T> List<T> formatContext(List<Map<String, Object>> messages, boolean
return context;
}

/**
* Validate and normalize role parameter for filtering messages.
*
* <p>Matches Python _validate_roles from base_history.py (lines 90-128)
*
* @param role A single role string, List of roles, or null
* @return List of valid role strings if role is provided, null otherwise
* @throws IllegalArgumentException if role contains invalid values or is the wrong type
*/
@SuppressWarnings("unchecked")
protected List<String> validateRoles(Object role) {
if (role == null) {
return null;
}

// Handle single role string
if (role instanceof String) {
String roleStr = (String) role;
if (!VALID_ROLES.contains(roleStr)) {
throw new IllegalArgumentException(
String.format("Invalid role '%s'. Valid roles are: %s", roleStr, VALID_ROLES));
}
return List.of(roleStr);
}

// Handle list of roles
if (role instanceof List) {
List<?> roleList = (List<?>) role;

if (roleList.isEmpty()) {
throw new IllegalArgumentException("roles cannot be empty");
}

// Validate all roles in the list
List<String> validatedRoles = new ArrayList<>();
for (Object r : roleList) {
if (!(r instanceof String)) {
throw new IllegalArgumentException(
"role list must contain only strings, found: " + r.getClass().getSimpleName());
}
String roleStr = (String) r;
if (!VALID_ROLES.contains(roleStr)) {
throw new IllegalArgumentException(
String.format("Invalid role '%s'. Valid roles are: %s", roleStr, VALID_ROLES));
}
validatedRoles.add(roleStr);
}

return validatedRoles;
}

throw new IllegalArgumentException("role must be a String, List<String>, or null");
}

public String getName() {
return name;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public void delete() {
public void drop(String id) {
if (id == null) {
// Get the most recent message
List<Map<String, Object>> recent = getRecent(1, false, true, null);
List<Map<String, Object>> recent = getRecent(1, false, true, null, null);
if (!recent.isEmpty()) {
id = (String) recent.get(0).get(ID_FIELD_NAME);
} else {
Expand Down Expand Up @@ -111,14 +111,31 @@ public List<Map<String, Object>> getMessages() {
return formatContext(messages, false);
}

/**
* Retrieve the recent conversation history (backward-compatible overload without role filter).
*
* @param topK The number of previous messages to return
* @param asText Whether to return as text strings or maps
* @param raw Whether to return full Redis hash entries
* @param sessionTag Session tag to filter by
* @return List of messages
*/
public <T> List<T> getRecent(int topK, boolean asText, boolean raw, String sessionTag) {
return getRecent(topK, asText, raw, sessionTag, null);
}

@Override
@SuppressWarnings("unchecked")
public <T> List<T> getRecent(int topK, boolean asText, boolean raw, String sessionTag) {
public <T> List<T> getRecent(
int topK, boolean asText, boolean raw, String sessionTag, Object role) {
// Validate topK
if (topK < 0) {
throw new IllegalArgumentException("topK must be an integer greater than or equal to 0");
}

// Validate and normalize role parameter
List<String> rolesToFilter = validateRoles(role);

List<String> returnFields =
List.of(
ID_FIELD_NAME,
Expand All @@ -131,9 +148,26 @@ public <T> List<T> getRecent(int topK, boolean asText, boolean raw, String sessi
Filter sessionFilter =
(sessionTag != null) ? Filter.tag(SESSION_FIELD_NAME, sessionTag) : defaultSessionFilter;

// Combine session filter with role filter if provided
Filter filterExpression = sessionFilter;
if (rolesToFilter != null) {
if (rolesToFilter.size() == 1) {
// Single role filter
Filter roleFilter = Filter.tag(ROLE_FIELD_NAME, rolesToFilter.get(0));
filterExpression = Filter.and(sessionFilter, roleFilter);
} else {
// Multiple roles - use OR logic
Filter roleFilter = Filter.tag(ROLE_FIELD_NAME, rolesToFilter.get(0));
for (int i = 1; i < rolesToFilter.size(); i++) {
roleFilter = Filter.or(roleFilter, Filter.tag(ROLE_FIELD_NAME, rolesToFilter.get(i)));
}
filterExpression = Filter.and(sessionFilter, roleFilter);
}
}

FilterQuery query =
FilterQuery.builder()
.filterExpression(sessionFilter)
.filterExpression(filterExpression)
.returnFields(returnFields)
.numResults(topK)
.sortBy(TIMESTAMP_FIELD_NAME)
Expand Down
Loading