Skip to content

Commit 6152e77

Browse files
committed
feat(query): add multi-vector query support (#402)
Add support for searching across multiple vector fields simultaneously with weighted score combination, following Python RedisVL PR #402. New Classes: - Vector: Data container for vector information with validation - Stores vector, field name, dtype, and weight - Validates dtype against supported Redis types - Ensures positive weights and non-empty vectors - MultiVectorQuery: Enables multi-vector search with weighted scoring - Accepts single or multiple Vector objects - Builds VECTOR_RANGE queries with distance threshold 2.0 - Combines scores using formula: w_1*score_1 + w_2*score_2 + ... - Supports filter expressions and return fields - Returns params map with vector_0, vector_1, etc. Key Features: - Simultaneous search across multiple vector fields - Weighted score combination - Individual score calculations: (2 - distance_i)/2 - Flexible builder API (single, varargs, list) - Comprehensive validation - 20 unit tests covering all scenarios Implementation Details: - Distance threshold hardcoded at 2.0 (includes all eligible docs) - Vectors may differ in size and datatype - Must use cosine distance metric - Immutable Vector class with defensive copying - Query string format: @field:[VECTOR_RANGE 2.0 $vec]=>{$YIELD_DISTANCE_AS: dist} Python Reference: redis/redis-vl-python#402
1 parent 37e9324 commit 6152e77

File tree

3 files changed

+872
-0
lines changed

3 files changed

+872
-0
lines changed
Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
package com.redis.vl.query;
2+
3+
import com.redis.vl.utils.ArrayUtils;
4+
import java.util.*;
5+
import lombok.Getter;
6+
7+
/**
8+
* MultiVectorQuery allows for search over multiple vector fields in a document simultaneously.
9+
*
10+
* <p>The final score will be a weighted combination of the individual vector similarity scores
11+
* following the formula:
12+
*
13+
* <p>score = (w_1 * score_1 + w_2 * score_2 + w_3 * score_3 + ... )
14+
*
15+
* <p>Vectors may be of different size and datatype, but must be indexed using the 'cosine'
16+
* distance_metric.
17+
*
18+
* <p>Ported from Python: redisvl/query/aggregate.py:257-400 (MultiVectorQuery class)
19+
*
20+
* <p>Python equivalent:
21+
*
22+
* <pre>
23+
* from redisvl.query import MultiVectorQuery, Vector
24+
*
25+
* vector_1 = Vector(vector=[0.1, 0.2, 0.3], field_name="text_vector", dtype="float32", weight=0.7)
26+
* vector_2 = Vector(vector=[0.5, 0.5], field_name="image_vector", dtype="bfloat16", weight=0.2)
27+
*
28+
* query = MultiVectorQuery(
29+
* vectors=[vector_1, vector_2],
30+
* filter_expression=None,
31+
* num_results=10,
32+
* return_fields=["field1", "field2"],
33+
* dialect=2
34+
* )
35+
* </pre>
36+
*
37+
* Java equivalent:
38+
*
39+
* <pre>
40+
* Vector vector1 = Vector.builder()
41+
* .vector(new float[]{0.1f, 0.2f, 0.3f})
42+
* .fieldName("text_vector")
43+
* .dtype("float32")
44+
* .weight(0.7)
45+
* .build();
46+
*
47+
* Vector vector2 = Vector.builder()
48+
* .vector(new float[]{0.5f, 0.5f})
49+
* .fieldName("image_vector")
50+
* .dtype("bfloat16")
51+
* .weight(0.2)
52+
* .build();
53+
*
54+
* MultiVectorQuery query = MultiVectorQuery.builder()
55+
* .vectors(Arrays.asList(vector1, vector2))
56+
* .numResults(10)
57+
* .returnFields(Arrays.asList("field1", "field2"))
58+
* .build();
59+
* </pre>
60+
*/
61+
@Getter
62+
public final class MultiVectorQuery {
63+
64+
/** Distance threshold for VECTOR_RANGE (hardcoded at 2.0 to include all eligible documents) */
65+
private static final double DISTANCE_THRESHOLD = 2.0;
66+
67+
private final List<Vector> vectors;
68+
private final Filter filterExpression;
69+
private final List<String> returnFields;
70+
private final int numResults;
71+
private final int dialect;
72+
73+
private MultiVectorQuery(Builder builder) {
74+
// Validate before modifying state
75+
if (builder.vectors == null || builder.vectors.isEmpty()) {
76+
throw new IllegalArgumentException("At least one Vector is required");
77+
}
78+
79+
// Validate all elements are Vector objects
80+
for (Vector v : builder.vectors) {
81+
if (v == null) {
82+
throw new IllegalArgumentException("Vector list cannot contain null values");
83+
}
84+
}
85+
86+
this.vectors = List.copyOf(builder.vectors);
87+
this.filterExpression = builder.filterExpression;
88+
this.returnFields =
89+
builder.returnFields != null ? List.copyOf(builder.returnFields) : List.of();
90+
this.numResults = builder.numResults;
91+
this.dialect = builder.dialect;
92+
}
93+
94+
/**
95+
* Create a new Builder for MultiVectorQuery.
96+
*
97+
* @return A new Builder instance
98+
*/
99+
public static Builder builder() {
100+
return new Builder();
101+
}
102+
103+
/**
104+
* Build the Redis query string for multi-vector search.
105+
*
106+
* <p>Format: {@code @field1:[VECTOR_RANGE 2.0 $vector_0]=>{$YIELD_DISTANCE_AS: distance_0} |
107+
* @field2:[VECTOR_RANGE 2.0 $vector_1]=>{$YIELD_DISTANCE_AS: distance_1}}
108+
*
109+
* @return Query string
110+
*/
111+
public String toQueryString() {
112+
List<String> rangeQueries = new ArrayList<>();
113+
114+
for (int i = 0; i < vectors.size(); i++) {
115+
Vector v = vectors.get(i);
116+
String rangeQuery =
117+
String.format(
118+
"@%s:[VECTOR_RANGE %.1f $vector_%d]=>{$YIELD_DISTANCE_AS: distance_%d}",
119+
v.getFieldName(), DISTANCE_THRESHOLD, i, i);
120+
rangeQueries.add(rangeQuery);
121+
}
122+
123+
String baseQuery = String.join(" | ", rangeQueries);
124+
125+
// Add filter expression if present
126+
if (filterExpression != null) {
127+
String filterStr = filterExpression.build();
128+
return String.format("(%s) AND (%s)", baseQuery, filterStr);
129+
}
130+
131+
return baseQuery;
132+
}
133+
134+
/**
135+
* Convert to parameter map for query execution.
136+
*
137+
* <p>Returns map with vector_0, vector_1, etc. as byte arrays
138+
*
139+
* @return Parameters map
140+
*/
141+
public Map<String, Object> toParams() {
142+
Map<String, Object> params = new HashMap<>();
143+
144+
for (int i = 0; i < vectors.size(); i++) {
145+
Vector v = vectors.get(i);
146+
byte[] vectorBytes = ArrayUtils.floatArrayToBytes(v.getVector());
147+
params.put(String.format("vector_%d", i), vectorBytes);
148+
}
149+
150+
return params;
151+
}
152+
153+
/**
154+
* Get the scoring formula for combining vector similarities.
155+
*
156+
* <p>Formula: w_1 * score_1 + w_2 * score_2 + ...
157+
*
158+
* <p>Where score_i = (2 - distance_i) / 2
159+
*
160+
* @return Scoring formula string
161+
*/
162+
public String getScoringFormula() {
163+
List<String> scoreTerms = new ArrayList<>();
164+
165+
for (int i = 0; i < vectors.size(); i++) {
166+
Vector v = vectors.get(i);
167+
scoreTerms.add(String.format("%.2f * score_%d", v.getWeight(), i));
168+
}
169+
170+
return String.join(" + ", scoreTerms);
171+
}
172+
173+
/**
174+
* Get individual score calculations.
175+
*
176+
* <p>Returns map of score_0=(2-distance_0)/2, score_1=(2-distance_1)/2, etc.
177+
*
178+
* @return Map of score names to calculation formulas
179+
*/
180+
public Map<String, String> getScoreCalculations() {
181+
Map<String, String> calculations = new LinkedHashMap<>();
182+
183+
for (int i = 0; i < vectors.size(); i++) {
184+
calculations.put(String.format("score_%d", i), String.format("(2 - distance_%d)/2", i));
185+
}
186+
187+
return calculations;
188+
}
189+
190+
@Override
191+
public String toString() {
192+
return toQueryString();
193+
}
194+
195+
/** Builder for creating MultiVectorQuery instances. */
196+
public static class Builder {
197+
private List<Vector> vectors;
198+
private Filter filterExpression;
199+
private List<String> returnFields;
200+
private int numResults = 10; // Default from Python
201+
private int dialect = 2; // Default from Python
202+
203+
Builder() {}
204+
205+
/**
206+
* Set the vectors to search (accepts a single Vector).
207+
*
208+
* @param vector Single Vector for search
209+
* @return This builder
210+
*/
211+
public Builder vector(Vector vector) {
212+
this.vectors = vector != null ? List.of(vector) : null;
213+
return this;
214+
}
215+
216+
/**
217+
* Set the vectors to search (accepts multiple Vectors as varargs).
218+
*
219+
* @param vectors Vectors for multi-vector search
220+
* @return This builder
221+
*/
222+
public Builder vectors(Vector... vectors) {
223+
this.vectors = vectors != null ? Arrays.asList(vectors) : null;
224+
return this;
225+
}
226+
227+
/**
228+
* Set the vectors to search (accepts a List of Vectors).
229+
*
230+
* @param vectors List of Vectors for multi-vector search
231+
* @return This builder
232+
*/
233+
public Builder vectors(List<Vector> vectors) {
234+
this.vectors = vectors != null ? new ArrayList<>(vectors) : null;
235+
return this;
236+
}
237+
238+
/**
239+
* Set the filter expression.
240+
*
241+
* @param filterExpression Filter to apply
242+
* @return This builder
243+
*/
244+
public Builder filterExpression(Filter filterExpression) {
245+
this.filterExpression = filterExpression;
246+
return this;
247+
}
248+
249+
/**
250+
* Set the fields to return in results (varargs).
251+
*
252+
* @param fields Field names to return
253+
* @return This builder
254+
*/
255+
public Builder returnFields(String... fields) {
256+
this.returnFields = Arrays.asList(fields);
257+
return this;
258+
}
259+
260+
/**
261+
* Set the fields to return in results (list).
262+
*
263+
* @param fields List of field names to return
264+
* @return This builder
265+
*/
266+
public Builder returnFields(List<String> fields) {
267+
this.returnFields = fields != null ? new ArrayList<>(fields) : null;
268+
return this;
269+
}
270+
271+
/**
272+
* Set the maximum number of results to return.
273+
*
274+
* @param numResults Maximum number of results
275+
* @return This builder
276+
*/
277+
public Builder numResults(int numResults) {
278+
this.numResults = numResults;
279+
return this;
280+
}
281+
282+
/**
283+
* Set the query dialect.
284+
*
285+
* @param dialect RediSearch dialect version
286+
* @return This builder
287+
*/
288+
public Builder dialect(int dialect) {
289+
this.dialect = dialect;
290+
return this;
291+
}
292+
293+
/**
294+
* Build the MultiVectorQuery instance.
295+
*
296+
* @return Configured MultiVectorQuery
297+
* @throws IllegalArgumentException if vectors is null/empty or contains null values
298+
*/
299+
public MultiVectorQuery build() {
300+
return new MultiVectorQuery(this);
301+
}
302+
}
303+
}

0 commit comments

Comments
 (0)