Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
300 changes: 221 additions & 79 deletions core/src/main/java/com/redis/vl/query/TextQuery.java
Original file line number Diff line number Diff line change
@@ -1,122 +1,264 @@
package com.redis.vl.query;

import java.util.ArrayList;
import java.util.List;
import com.redis.vl.utils.TokenEscaper;
import java.util.HashMap;
import java.util.Map;
import lombok.Getter;

/** Full-text search query */
/**
* Full-text search query with support for field weights.
*
* <p>Python port: Implements text_field_name with Union[str, Dict[str, float]] for weighted text
* search across multiple fields.
*
* <p>Example usage:
*
* <pre>{@code
* // Single field (backward compatible)
* TextQuery query = TextQuery.builder()
* .text("search terms")
* .textField("description")
* .build();
*
* // Multiple fields with weights
* TextQuery query = TextQuery.builder()
* .text("search terms")
* .textFieldWeights(Map.of("title", 5.0, "content", 2.0, "tags", 1.0))
* .build();
* }</pre>
*/
@Getter
public class TextQuery {

private final String text;
private final String textField;
private final String scorer;
private final Filter filterExpression;
private final List<String> returnFields;
private final Integer numResults;

/**
* Create a text query without a filter expression.
*
* @param text The text to search for
* @param textField The field to search in
* @param scorer The scoring algorithm (e.g., "BM25", "TF IDF")
* @param returnFields List of fields to return in results
*/
public TextQuery(String text, String textField, String scorer, List<String> returnFields) {
this(text, textField, scorer, null, returnFields);
}
/** Field names mapped to their search weights */
private Map<String, Double> fieldWeights;

/**
* Create a text query with all parameters.
*
* @param text The text to search for
* @param textField The field to search in
* @param scorer The scoring algorithm
* @param filterExpression Optional filter to apply
* @param returnFields List of fields to return in results
*/
public TextQuery(
String text,
String textField,
String scorer,
Filter filterExpression,
List<String> returnFields) {
this.text = text;
this.textField = textField;
this.scorer = scorer;
this.filterExpression = filterExpression;
this.returnFields = returnFields != null ? new ArrayList<>(returnFields) : null;
private TextQuery(Builder builder) {
this.text = builder.text;
this.scorer = builder.scorer;
this.filterExpression = builder.filterExpression;
this.numResults = builder.numResults;
this.fieldWeights = new HashMap<>(builder.fieldWeights);
}

/**
* Get the search text
* Update the field weights dynamically.
*
* @return Search text
* @param fieldWeights Map of field names to weights
*/
public String getText() {
return text;
public void setFieldWeights(Map<String, Double> fieldWeights) {
validateFieldWeights(fieldWeights);
this.fieldWeights = new HashMap<>(fieldWeights);
}

/**
* Get the text field to search in
* Get a copy of the field weights.
*
* @return Text field name
* @return Map of field names to weights
*/
public String getTextField() {
return textField;
public Map<String, Double> getFieldWeights() {
return new HashMap<>(fieldWeights);
}

/**
* Get the scoring algorithm
* Build the Redis query string for text search with weighted fields.
*
* <p>Format:
*
* @return Scorer name
* <ul>
* <li>Single field default weight: {@code @field:(term1 | term2)}
* <li>Single field with weight: {@code @field:(term1 | term2) => { $weight: 5.0 }}
* <li>Multiple fields: {@code (@field1:(terms) => { $weight: 3.0 } | @field2:(terms) => {
* $weight: 2.0 })}
* </ul>
*
* @return Redis query string
*/
public String getScorer() {
return scorer;
public String toQueryString() {
TokenEscaper escaper = new TokenEscaper();

// Tokenize and escape the query text
String[] tokens = text.split("\\s+");
StringBuilder escapedQuery = new StringBuilder();

for (int i = 0; i < tokens.length; i++) {
if (i > 0) {
escapedQuery.append(" | ");
}
String cleanToken =
tokens[i].strip().stripLeading().stripTrailing().replace(",", "").toLowerCase();
escapedQuery.append(escaper.escape(cleanToken));
}

String escapedText = escapedQuery.toString();

// Build query parts for each field with its weight
StringBuilder queryBuilder = new StringBuilder();
int fieldCount = 0;

for (Map.Entry<String, Double> entry : fieldWeights.entrySet()) {
String field = entry.getKey();
Double weight = entry.getValue();

if (fieldCount > 0) {
queryBuilder.append(" | ");
}

queryBuilder.append("@").append(field).append(":(").append(escapedText).append(")");

// Add weight modifier if not default
if (weight != 1.0) {
queryBuilder.append(" => { $weight: ").append(weight).append(" }");
}

fieldCount++;
}

// Wrap multiple fields in parentheses
String textQuery;
if (fieldWeights.size() > 1) {
textQuery = "(" + queryBuilder.toString() + ")";
} else {
textQuery = queryBuilder.toString();
}

// Add filter expression if present
if (filterExpression != null) {
return textQuery + " AND " + filterExpression.build();
}

return textQuery;
}

/**
* Get the filter expression
*
* @return Filter expression or null
*/
public Filter getFilterExpression() {
return filterExpression;
@Override
public String toString() {
return toQueryString();
}

/**
* Get the return fields
* Create a new Builder for TextQuery.
*
* @return List of fields to return or null
* @return Builder instance
*/
public List<String> getReturnFields() {
return returnFields != null ? new ArrayList<>(returnFields) : null;
public static Builder builder() {
return new Builder();
}

/**
* Build the query string for Redis text search
*
* @return Query string
*/
public String toQueryString() {
StringBuilder query = new StringBuilder();
/** Builder for TextQuery with support for field weights. */
public static class Builder {
private String text;
private String scorer = "BM25STD";
private Filter filterExpression;
private Integer numResults = 10;
private Map<String, Double> fieldWeights = new HashMap<>();

// Add filter expression if present
if (filterExpression != null) {
query.append(filterExpression.build()).append(" ");
/**
* Set the text to search for.
*
* @param text Search text
* @return Builder
*/
public Builder text(String text) {
this.text = text;
return this;
}

// Add text search
if (textField != null && !textField.isEmpty()) {
query.append("@").append(textField).append(":(").append(text).append(")");
} else {
// Search all text fields
query.append(text);
/**
* Set a single text field to search (backward compatible).
*
* @param fieldName Field name
* @return Builder
*/
public Builder textField(String fieldName) {
this.fieldWeights = Map.of(fieldName, 1.0);
return this;
}

return query.toString();
/**
* Set multiple text fields with weights.
*
* @param fieldWeights Map of field names to weights
* @return Builder
*/
public Builder textFieldWeights(Map<String, Double> fieldWeights) {
validateFieldWeights(fieldWeights);
this.fieldWeights = new HashMap<>(fieldWeights);
return this;
}

/**
* Set the scoring algorithm.
*
* @param scorer Scorer name (e.g., BM25STD, TFIDF)
* @return Builder
*/
public Builder scorer(String scorer) {
this.scorer = scorer;
return this;
}

/**
* Set the filter expression.
*
* @param filterExpression Filter to apply
* @return Builder
*/
public Builder filterExpression(Filter filterExpression) {
this.filterExpression = filterExpression;
return this;
}

/**
* Set the number of results to return.
*
* @param numResults Number of results
* @return Builder
*/
public Builder numResults(int numResults) {
this.numResults = numResults;
return this;
}

/**
* Build the TextQuery instance.
*
* @return TextQuery
* @throws IllegalArgumentException if text is null or field weights are empty
*/
public TextQuery build() {
if (text == null || text.trim().isEmpty()) {
throw new IllegalArgumentException("Text cannot be null or empty");
}
if (fieldWeights.isEmpty()) {
throw new IllegalArgumentException("At least one text field must be specified");
}
return new TextQuery(this);
}
}

@Override
public String toString() {
return toQueryString();
/**
* Validate field weights.
*
* @param fieldWeights Map to validate
* @throws IllegalArgumentException if weights are invalid
*/
private static void validateFieldWeights(Map<String, Double> fieldWeights) {
for (Map.Entry<String, Double> entry : fieldWeights.entrySet()) {
String field = entry.getKey();
Double weight = entry.getValue();

if (weight == null) {
throw new IllegalArgumentException("Weight for field '" + field + "' cannot be null");
}
if (weight <= 0) {
throw new IllegalArgumentException(
"Weight for field '" + field + "' must be positive, got " + weight);
}
}
}
}
12 changes: 10 additions & 2 deletions core/src/test/java/com/redis/vl/query/QueryIntegrationTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,8 @@ void testTextQuery() {
List<String> scorers = Arrays.asList("BM25", "TFIDF", "TFIDF.DOCNORM", "DISMAX", "DOCSCORE");

for (String scorer : scorers) {
TextQuery textQuery = new TextQuery(text, textField, scorer, returnFields);
TextQuery textQuery =
TextQuery.builder().text(text).textField(textField).scorer(scorer).numResults(10).build();
List<Map<String, Object>> results = index.query(textQuery);
assertThat(results)
.as("TextQuery with scorer " + scorer + " should return results")
Expand Down Expand Up @@ -630,7 +631,14 @@ void testTextQueryWithFilter() {
Filter.and(Filter.tag("credit_score", "high"), Filter.numeric("age").gt(30));
String scorer = "TFIDF";

TextQuery textQuery = new TextQuery(text, textField, scorer, filterExpression, returnFields);
TextQuery textQuery =
TextQuery.builder()
.text(text)
.textField(textField)
.scorer(scorer)
.filterExpression(filterExpression)
.numResults(10)
.build();
List<Map<String, Object>> results = index.query(textQuery);

assertThat(results).hasSize(2); // mary and derrick
Expand Down
Loading