Skip to content

Commit e39b6bc

Browse files
committed
feat(rerank): add VoyageAI reranker and enhance CohereReranker
CohereReranker enhancements: - Add runtime parameter overrides (limit, return_score, rank_by) - Add max_chunks_per_doc support for long documents - Add integration and unit tests for new features
1 parent 39ecf0a commit e39b6bc

File tree

9 files changed

+1423
-189
lines changed

9 files changed

+1423
-189
lines changed

core/src/main/java/com/redis/vl/utils/rerank/CohereReranker.java

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,23 @@ private Object createCohereClient(String apiKey) {
161161
*/
162162
@Override
163163
public RerankResult rank(String query, List<?> docs) {
164+
return rank(query, docs, Collections.emptyMap());
165+
}
166+
167+
/**
168+
* Rerank documents based on query relevance using Cohere's Rerank API with runtime parameter
169+
* overrides.
170+
*
171+
* @param query The search query
172+
* @param docs List of documents (either List&lt;String&gt; or List&lt;Map&lt;String,
173+
* Object&gt;&gt;)
174+
* @param kwargs Optional parameters to override defaults (limit, return_score, rank_by,
175+
* max_chunks_per_doc)
176+
* @return RerankResult with reranked documents and relevance scores
177+
* @throws IllegalArgumentException if query or docs are invalid
178+
*/
179+
@SuppressWarnings("unchecked")
180+
public RerankResult rank(String query, List<?> docs, Map<String, Object> kwargs) {
164181
validateQuery(query);
165182
validateDocs(docs);
166183

@@ -173,20 +190,37 @@ public RerankResult rank(String query, List<?> docs) {
173190
initializeClient();
174191
}
175192

193+
// Extract runtime parameters with defaults
194+
int effectiveLimit = (Integer) kwargs.getOrDefault("limit", this.limit);
195+
boolean effectiveReturnScore = (Boolean) kwargs.getOrDefault("return_score", this.returnScore);
196+
Object maxChunksPerDoc = kwargs.get("max_chunks_per_doc");
197+
198+
// Handle rank_by override
199+
List<String> effectiveRankBy = this.rankBy;
200+
if (kwargs.containsKey("rank_by")) {
201+
Object rankByValue = kwargs.get("rank_by");
202+
if (rankByValue instanceof List) {
203+
effectiveRankBy = (List<String>) rankByValue;
204+
} else if (rankByValue instanceof String) {
205+
effectiveRankBy = Collections.singletonList((String) rankByValue);
206+
}
207+
}
208+
176209
// Determine if we're working with strings or dicts
177210
boolean isDictDocs = !docs.isEmpty() && docs.get(0) instanceof Map;
178211

179-
if (isDictDocs && (rankBy == null || rankBy.isEmpty())) {
212+
if (isDictDocs && (effectiveRankBy == null || effectiveRankBy.isEmpty())) {
180213
throw new IllegalArgumentException(
181214
"If reranking dictionary-like docs, you must provide a list of rankBy fields");
182215
}
183216

184217
try {
185218
// Call Cohere rerank API
186-
Object response = callRerankApi(query, docs, isDictDocs);
219+
Object response =
220+
callRerankApi(query, docs, isDictDocs, effectiveLimit, effectiveRankBy, maxChunksPerDoc);
187221

188222
// Extract results
189-
return extractResults(docs, response);
223+
return extractResults(docs, response, effectiveReturnScore);
190224

191225
} catch (Exception e) {
192226
throw new RuntimeException("Failed to call Cohere rerank API", e);
@@ -199,10 +233,20 @@ public RerankResult rank(String query, List<?> docs) {
199233
* @param query The search query
200234
* @param docs Documents to rerank
201235
* @param isDictDocs Whether documents are Maps or Strings
236+
* @param limit Maximum number of results
237+
* @param rankBy Fields to rank by (for dict documents)
238+
* @param maxChunksPerDoc Maximum chunks per document (optional)
202239
* @return Response from Cohere API
203240
* @throws Exception if API call fails
204241
*/
205-
private Object callRerankApi(String query, List<?> docs, boolean isDictDocs) throws Exception {
242+
private Object callRerankApi(
243+
String query,
244+
List<?> docs,
245+
boolean isDictDocs,
246+
int limit,
247+
List<String> rankBy,
248+
Object maxChunksPerDoc)
249+
throws Exception {
206250
// Get RerankRequest.builder()
207251
Class<?> rerankRequestClass = Class.forName("com.cohere.api.requests.RerankRequest");
208252
Object requestBuilder = rerankRequestClass.getMethod("builder").invoke(null);
@@ -238,6 +282,14 @@ private Object callRerankApi(String query, List<?> docs, boolean isDictDocs) thr
238282
finalStage = finalStageClass.getMethod("rankFields", List.class).invoke(finalStage, rankBy);
239283
}
240284

285+
// Set maxChunksPerDoc if provided
286+
if (maxChunksPerDoc != null) {
287+
finalStage =
288+
finalStageClass
289+
.getMethod("maxChunksPerDoc", Integer.class)
290+
.invoke(finalStage, maxChunksPerDoc);
291+
}
292+
241293
// Build the request
242294
Object request = finalStageClass.getMethod("build").invoke(finalStage);
243295

@@ -284,10 +336,12 @@ private List<Object> convertToDocumentItems(List<?> docs, boolean isDictDocs) th
284336
*
285337
* @param originalDocs Original documents
286338
* @param response Cohere API response
339+
* @param returnScore Whether to return scores
287340
* @return RerankResult with reranked documents and scores
288341
* @throws Exception if extraction fails
289342
*/
290-
private RerankResult extractResults(List<?> originalDocs, Object response) throws Exception {
343+
private RerankResult extractResults(List<?> originalDocs, Object response, boolean returnScore)
344+
throws Exception {
291345
// Get results from response
292346
Class<?> responseClass = response.getClass();
293347
List<?> results = (List<?>) responseClass.getMethod("getResults").invoke(response);
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
package com.redis.vl.utils.rerank;
2+
3+
import com.fasterxml.jackson.databind.JsonNode;
4+
import com.fasterxml.jackson.databind.ObjectMapper;
5+
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
6+
import java.io.IOException;
7+
import java.util.*;
8+
import lombok.Builder;
9+
import okhttp3.*;
10+
11+
/**
12+
* Reranker that uses VoyageAI's Rerank API to rerank documents based on query relevance.
13+
*
14+
* <p>This reranker interacts with VoyageAI's /v1/rerank API, requiring an API key for
15+
* authentication. The API key can be provided directly in the {@code apiConfig} Map or through the
16+
* {@code VOYAGE_API_KEY} environment variable.
17+
*
18+
* <p>Users must obtain an API key from <a href="https://dash.voyageai.com/">VoyageAI Dashboard</a>.
19+
*
20+
* <p>Example usage:
21+
*
22+
* <pre>{@code
23+
* // Initialize with API key
24+
* Map<String, String> apiConfig = Map.of("api_key", "your-api-key");
25+
* VoyageAIReranker reranker = VoyageAIReranker.builder()
26+
* .model("rerank-lite-1")
27+
* .apiConfig(apiConfig)
28+
* .limit(3)
29+
* .build();
30+
*
31+
* // Rerank string documents
32+
* List<String> docs = Arrays.asList("doc1", "doc2", "doc3");
33+
* RerankResult result = reranker.rank("query", docs);
34+
* }</pre>
35+
*
36+
* @see <a href="https://docs.voyageai.com/docs/reranker">VoyageAI Rerank API</a>
37+
*/
38+
@SuppressFBWarnings(
39+
value = "EI_EXPOSE_REP2",
40+
justification =
41+
"Lombok @Builder generates methods that store mutable objects, "
42+
+ "but defensive copies are made in constructor")
43+
public class VoyageAIReranker extends BaseReranker {
44+
45+
private static final String API_ENDPOINT = "https://api.voyageai.com/v1/rerank";
46+
private static final MediaType JSON = MediaType.get("application/json; charset=utf-8");
47+
48+
private final Map<String, String> apiConfig;
49+
private final OkHttpClient httpClient;
50+
private final ObjectMapper objectMapper;
51+
52+
/**
53+
* Create a VoyageAIReranker using the builder.
54+
*
55+
* @param model The VoyageAI model to use (e.g., "rerank-lite-1", "rerank-2")
56+
* @param limit Maximum number of results to return (default: 5)
57+
* @param returnScore Whether to return relevance scores (default: true)
58+
* @param apiConfig Map containing API configuration (must have "api_key" key)
59+
*/
60+
@Builder
61+
@SuppressFBWarnings(
62+
value = "EI_EXPOSE_REP2",
63+
justification =
64+
"Lombok builder generates methods that store mutable objects, "
65+
+ "but defensive copies are made in constructor")
66+
private VoyageAIReranker(
67+
String model, Integer limit, Boolean returnScore, Map<String, String> apiConfig) {
68+
super(
69+
model != null ? model : "rerank-lite-1",
70+
null, // VoyageAI doesn't support rankBy
71+
limit != null ? limit : 5,
72+
returnScore != null ? returnScore : true);
73+
74+
// Make defensive copy of apiConfig
75+
this.apiConfig =
76+
apiConfig != null ? Collections.unmodifiableMap(new HashMap<>(apiConfig)) : null;
77+
78+
this.httpClient = new OkHttpClient();
79+
this.objectMapper = new ObjectMapper();
80+
}
81+
82+
/**
83+
* Rerank documents based on query relevance using VoyageAI's Rerank API.
84+
*
85+
* @param query The search query
86+
* @param docs List of documents (must be List&lt;String&gt;)
87+
* @return RerankResult with reranked documents and relevance scores
88+
* @throws IllegalArgumentException if query or docs are invalid
89+
*/
90+
@Override
91+
public RerankResult rank(String query, List<?> docs) {
92+
return rank(query, docs, Collections.emptyMap());
93+
}
94+
95+
/**
96+
* Rerank documents based on query relevance using VoyageAI's Rerank API with runtime parameter
97+
* overrides.
98+
*
99+
* @param query The search query
100+
* @param docs List of documents (must be List&lt;String&gt;)
101+
* @param kwargs Optional parameters to override defaults (limit, return_score, truncation)
102+
* @return RerankResult with reranked documents and relevance scores
103+
* @throws IllegalArgumentException if query or docs are invalid
104+
*/
105+
@SuppressWarnings("unchecked")
106+
public RerankResult rank(String query, List<?> docs, Map<String, Object> kwargs) {
107+
validateQuery(query);
108+
validateDocs(docs);
109+
110+
if (docs.isEmpty()) {
111+
return new RerankResult(Collections.emptyList(), Collections.emptyList());
112+
}
113+
114+
// Get API key from config or environment
115+
String apiKey = null;
116+
if (apiConfig != null && apiConfig.containsKey("api_key")) {
117+
apiKey = apiConfig.get("api_key");
118+
}
119+
if (apiKey == null || apiKey.isEmpty()) {
120+
apiKey = System.getenv("VOYAGE_API_KEY");
121+
}
122+
if (apiKey == null || apiKey.isEmpty()) {
123+
throw new IllegalArgumentException(
124+
"VoyageAI API key is required. "
125+
+ "Provide it in apiConfig or set the VOYAGE_API_KEY environment variable.");
126+
}
127+
128+
// Extract runtime parameters with defaults
129+
int effectiveLimit = (Integer) kwargs.getOrDefault("limit", this.limit);
130+
boolean effectiveReturnScore = (Boolean) kwargs.getOrDefault("return_score", this.returnScore);
131+
Object truncation = kwargs.get("truncation");
132+
133+
// VoyageAI only supports string documents
134+
List<String> stringDocs;
135+
if (docs.get(0) instanceof Map) {
136+
// Extract "content" field from dict docs
137+
stringDocs = new ArrayList<>();
138+
for (Object doc : docs) {
139+
Map<String, Object> docMap = (Map<String, Object>) doc;
140+
if (docMap.containsKey("content")) {
141+
stringDocs.add(String.valueOf(docMap.get("content")));
142+
} else {
143+
throw new IllegalArgumentException(
144+
"VoyageAI reranker requires documents to be strings or have a 'content' field");
145+
}
146+
}
147+
} else {
148+
stringDocs = (List<String>) docs;
149+
}
150+
151+
try {
152+
// Build request JSON
153+
Map<String, Object> requestBody = new HashMap<>();
154+
requestBody.put("query", query);
155+
requestBody.put("documents", stringDocs);
156+
requestBody.put("model", model);
157+
requestBody.put("top_k", effectiveLimit);
158+
if (truncation != null) {
159+
requestBody.put("truncation", truncation);
160+
}
161+
162+
String jsonBody = objectMapper.writeValueAsString(requestBody);
163+
164+
// Make HTTP request
165+
Request request =
166+
new Request.Builder()
167+
.url(API_ENDPOINT)
168+
.post(RequestBody.create(jsonBody, JSON))
169+
.addHeader("Authorization", "Bearer " + apiKey)
170+
.addHeader("Content-Type", "application/json")
171+
.build();
172+
173+
try (Response response = httpClient.newCall(request).execute()) {
174+
if (!response.isSuccessful()) {
175+
throw new RuntimeException(
176+
"VoyageAI API request failed: "
177+
+ response.code()
178+
+ " "
179+
+ (response.body() != null ? response.body().string() : ""));
180+
}
181+
182+
ResponseBody responseBody = response.body();
183+
if (responseBody == null) {
184+
throw new RuntimeException("VoyageAI API returned null response body");
185+
}
186+
187+
String responseBodyString = responseBody.string();
188+
JsonNode jsonResponse = objectMapper.readTree(responseBodyString);
189+
JsonNode results = jsonResponse.get("data");
190+
191+
List<Object> rerankedDocs = new ArrayList<>();
192+
List<Double> scores = new ArrayList<>();
193+
194+
for (JsonNode result : results) {
195+
int index = result.get("index").asInt();
196+
double score = result.get("relevance_score").asDouble();
197+
198+
rerankedDocs.add(docs.get(index));
199+
scores.add(score);
200+
}
201+
202+
return new RerankResult(rerankedDocs, effectiveReturnScore ? scores : null);
203+
}
204+
205+
} catch (IOException e) {
206+
throw new RuntimeException("Failed to call VoyageAI rerank API", e);
207+
}
208+
}
209+
}

0 commit comments

Comments
 (0)