Skip to content

Commit 39ecf0a

Browse files
committed
feat(rerank): add Cohere reranker implementation
- Implement CohereReranker with reflection-based Cohere SDK integration - Add integration tests matching Python notebook scenarios - Add comprehensive unit tests for edge cases and validation - Update notebook with Cohere examples and API key loading - Add Cohere Java SDK dependency (compileOnly + test) Achieves feature parity with Python redis-vl Cohere reranker. Test results show perfect score matching (within floating-point precision).
1 parent 96b0d03 commit 39ecf0a

File tree

5 files changed

+927
-15
lines changed

5 files changed

+927
-15
lines changed

core/build.gradle.kts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,17 @@ dependencies {
6060
// HuggingFace tokenizers for all transformer models (BERT, XLMRoberta, etc)
6161
implementation("ai.djl.huggingface:tokenizers:0.30.0")
6262

63+
// Cohere Java SDK for reranking
64+
compileOnly("com.cohere:cohere-java:1.8.1")
65+
6366
// Test dependencies for LangChain4J (include in tests to verify integration)
6467
testImplementation("dev.langchain4j:langchain4j:0.36.2")
6568
testImplementation("dev.langchain4j:langchain4j-embeddings-all-minilm-l6-v2:0.36.2")
6669
testImplementation("dev.langchain4j:langchain4j-hugging-face:0.36.2")
6770

71+
// Cohere for integration tests
72+
testImplementation("com.cohere:cohere-java:1.8.1")
73+
6874
// Additional test dependencies
6975
testImplementation("com.squareup.okhttp3:mockwebserver:4.12.0")
7076
testImplementation("org.mockito:mockito-core:5.11.0")
Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
package com.redis.vl.utils.rerank;
2+
3+
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
4+
import java.util.*;
5+
import java.util.stream.Collectors;
6+
import lombok.Builder;
7+
8+
/**
9+
* Reranker that uses Cohere's Rerank API to rerank documents based on query relevance.
10+
*
11+
* <p>This reranker interacts with Cohere's /rerank API, requiring an API key for authentication.
12+
* The API key can be provided directly in the {@code apiConfig} Map or through the {@code
13+
* COHERE_API_KEY} environment variable.
14+
*
15+
* <p>Users must obtain an API key from <a
16+
* href="https://dashboard.cohere.com/">https://dashboard.cohere.com/</a>. Additionally, the {@code
17+
* com.cohere:cohere-java} library must be available on the classpath.
18+
*
19+
* <p>Example usage:
20+
*
21+
* <pre>{@code
22+
* // Initialize with API key
23+
* Map<String, String> apiConfig = Map.of("api_key", "your-api-key");
24+
* CohereReranker reranker = CohereReranker.builder()
25+
* .apiConfig(apiConfig)
26+
* .limit(3)
27+
* .build();
28+
*
29+
* // Rerank string documents
30+
* List<String> docs = Arrays.asList("doc1", "doc2", "doc3");
31+
* RerankResult result = reranker.rank("query", docs);
32+
*
33+
* // Rerank dict documents with rankBy fields
34+
* List<Map<String, Object>> dictDocs = Arrays.asList(
35+
* Map.of("content", "doc1", "source", "wiki"),
36+
* Map.of("content", "doc2", "source", "textbook")
37+
* );
38+
* CohereReranker reranker2 = CohereReranker.builder()
39+
* .apiConfig(apiConfig)
40+
* .rankBy(Arrays.asList("content", "source"))
41+
* .build();
42+
* RerankResult result2 = reranker2.rank("query", dictDocs);
43+
* }</pre>
44+
*
45+
* @see <a href="https://docs.cohere.com/reference/rerank">Cohere Rerank API</a>
46+
*/
47+
@SuppressFBWarnings(
48+
value = "EI_EXPOSE_REP2",
49+
justification =
50+
"Lombok @Builder generates methods that store mutable objects, "
51+
+ "but defensive copies are made in constructor")
52+
public class CohereReranker extends BaseReranker {
53+
54+
private final Map<String, String> apiConfig;
55+
private volatile Object client; // com.cohere.api.Cohere (lazy loaded)
56+
57+
/**
58+
* Create a CohereReranker using the builder.
59+
*
60+
* @param model The Cohere model to use (default: "rerank-english-v3.0")
61+
* @param rankBy List of fields to rank by for dict documents (optional)
62+
* @param limit Maximum number of results to return (default: 5)
63+
* @param returnScore Whether to return relevance scores (default: true)
64+
* @param apiConfig Map containing API configuration (must have "api_key" key)
65+
*/
66+
@Builder
67+
@SuppressFBWarnings(
68+
value = "EI_EXPOSE_REP2",
69+
justification =
70+
"Lombok builder generates methods that store mutable objects, "
71+
+ "but defensive copies are made in constructor")
72+
private CohereReranker(
73+
String model,
74+
List<String> rankBy,
75+
Integer limit,
76+
Boolean returnScore,
77+
Map<String, String> apiConfig) {
78+
super(
79+
model != null ? model : "rerank-english-v3.0",
80+
rankBy,
81+
limit != null ? limit : 5,
82+
returnScore != null ? returnScore : true);
83+
84+
// Make defensive copy of apiConfig to avoid EI2 SpotBugs warning
85+
this.apiConfig =
86+
apiConfig != null ? Collections.unmodifiableMap(new HashMap<>(apiConfig)) : null;
87+
}
88+
89+
/** Initialize the Cohere client using the API key from apiConfig or environment. */
90+
private synchronized void initializeClient() {
91+
if (client != null) {
92+
return;
93+
}
94+
95+
// Check for Cohere SDK availability
96+
try {
97+
Class.forName("com.cohere.api.Cohere");
98+
} catch (ClassNotFoundException e) {
99+
throw new IllegalStateException(
100+
"Cohere reranker requires the cohere-java library. "
101+
+ "Please add dependency: com.cohere:cohere-java:1.8.1",
102+
e);
103+
}
104+
105+
// Get API key from config or environment
106+
String apiKey = null;
107+
if (apiConfig != null && apiConfig.containsKey("api_key")) {
108+
apiKey = apiConfig.get("api_key");
109+
}
110+
if (apiKey == null || apiKey.isEmpty()) {
111+
apiKey = System.getenv("COHERE_API_KEY");
112+
}
113+
if (apiKey == null || apiKey.isEmpty()) {
114+
throw new IllegalArgumentException(
115+
"Cohere API key is required. "
116+
+ "Provide it in apiConfig or set the COHERE_API_KEY environment variable.");
117+
}
118+
119+
// Create Cohere client
120+
this.client = createCohereClient(apiKey);
121+
}
122+
123+
/**
124+
* Create a Cohere client instance.
125+
*
126+
* @param apiKey The API key
127+
* @return Cohere client
128+
*/
129+
private Object createCohereClient(String apiKey) {
130+
try {
131+
// Import Cohere class
132+
Class<?> cohereClass = Class.forName("com.cohere.api.Cohere");
133+
134+
// Call Cohere.builder()
135+
Object builder = cohereClass.getMethod("builder").invoke(null);
136+
137+
// Get builder class
138+
Class<?> builderClass = builder.getClass();
139+
140+
// Call .token(apiKey)
141+
builderClass.getMethod("token", String.class).invoke(builder, apiKey);
142+
143+
// Call .clientName("redisvl4j")
144+
builderClass.getMethod("clientName", String.class).invoke(builder, "redisvl4j");
145+
146+
// Call .build()
147+
return builderClass.getMethod("build").invoke(builder);
148+
} catch (Exception e) {
149+
throw new IllegalStateException("Failed to create Cohere client", e);
150+
}
151+
}
152+
153+
/**
154+
* Rerank documents based on query relevance using Cohere's Rerank API.
155+
*
156+
* @param query The search query
157+
* @param docs List of documents (either List&lt;String&gt; or List&lt;Map&lt;String,
158+
* Object&gt;&gt;)
159+
* @return RerankResult with reranked documents and relevance scores
160+
* @throws IllegalArgumentException if query or docs are invalid
161+
*/
162+
@Override
163+
public RerankResult rank(String query, List<?> docs) {
164+
validateQuery(query);
165+
validateDocs(docs);
166+
167+
if (docs.isEmpty()) {
168+
return new RerankResult(Collections.emptyList(), Collections.emptyList());
169+
}
170+
171+
// Lazy initialize client
172+
if (client == null) {
173+
initializeClient();
174+
}
175+
176+
// Determine if we're working with strings or dicts
177+
boolean isDictDocs = !docs.isEmpty() && docs.get(0) instanceof Map;
178+
179+
if (isDictDocs && (rankBy == null || rankBy.isEmpty())) {
180+
throw new IllegalArgumentException(
181+
"If reranking dictionary-like docs, you must provide a list of rankBy fields");
182+
}
183+
184+
try {
185+
// Call Cohere rerank API
186+
Object response = callRerankApi(query, docs, isDictDocs);
187+
188+
// Extract results
189+
return extractResults(docs, response);
190+
191+
} catch (Exception e) {
192+
throw new RuntimeException("Failed to call Cohere rerank API", e);
193+
}
194+
}
195+
196+
/**
197+
* Call the Cohere rerank API.
198+
*
199+
* @param query The search query
200+
* @param docs Documents to rerank
201+
* @param isDictDocs Whether documents are Maps or Strings
202+
* @return Response from Cohere API
203+
* @throws Exception if API call fails
204+
*/
205+
private Object callRerankApi(String query, List<?> docs, boolean isDictDocs) throws Exception {
206+
// Get RerankRequest.builder()
207+
Class<?> rerankRequestClass = Class.forName("com.cohere.api.requests.RerankRequest");
208+
Object requestBuilder = rerankRequestClass.getMethod("builder").invoke(null);
209+
210+
// Get builder class (it's a staged builder, so we need to follow the stages)
211+
Class<?> currentStageClass = requestBuilder.getClass();
212+
213+
// Set query (QueryStage -> DocumentsStage)
214+
Object documentsStage =
215+
currentStageClass.getMethod("query", String.class).invoke(requestBuilder, query);
216+
217+
// Convert docs to RerankRequestDocumentsItem list
218+
List<Object> documentItems = convertToDocumentItems(docs, isDictDocs);
219+
220+
// Set documents (DocumentsStage -> _FinalStage)
221+
Class<?> documentsStageClass = documentsStage.getClass();
222+
Object finalStage =
223+
documentsStageClass
224+
.getMethod("documents", List.class)
225+
.invoke(documentsStage, documentItems);
226+
227+
// On the final stage, set optional parameters
228+
Class<?> finalStageClass = finalStage.getClass();
229+
230+
// Set model if not default
231+
finalStage = finalStageClass.getMethod("model", String.class).invoke(finalStage, model);
232+
233+
// Set topN (limit)
234+
finalStage = finalStageClass.getMethod("topN", Integer.class).invoke(finalStage, limit);
235+
236+
// Set rankFields for dict documents
237+
if (isDictDocs && rankBy != null && !rankBy.isEmpty()) {
238+
finalStage = finalStageClass.getMethod("rankFields", List.class).invoke(finalStage, rankBy);
239+
}
240+
241+
// Build the request
242+
Object request = finalStageClass.getMethod("build").invoke(finalStage);
243+
244+
// Call client.rerank(request)
245+
Class<?> cohereClass = client.getClass();
246+
return cohereClass.getMethod("rerank", rerankRequestClass).invoke(client, request);
247+
}
248+
249+
/**
250+
* Convert documents to RerankRequestDocumentsItem list.
251+
*
252+
* @param docs Documents
253+
* @param isDictDocs Whether documents are Maps or Strings
254+
* @return List of RerankRequestDocumentsItem
255+
* @throws Exception if conversion fails
256+
*/
257+
private List<Object> convertToDocumentItems(List<?> docs, boolean isDictDocs) throws Exception {
258+
Class<?> documentItemClass = Class.forName("com.cohere.api.types.RerankRequestDocumentsItem");
259+
260+
List<Object> result = new ArrayList<>();
261+
for (Object doc : docs) {
262+
Object item;
263+
if (isDictDocs) {
264+
// Convert Map<String, Object> to Map<String, String> for Cohere API
265+
@SuppressWarnings("unchecked")
266+
Map<String, Object> docMap = (Map<String, Object>) doc;
267+
Map<String, String> stringMap =
268+
docMap.entrySet().stream()
269+
.collect(Collectors.toMap(Map.Entry::getKey, e -> String.valueOf(e.getValue())));
270+
271+
// Call RerankRequestDocumentsItem.of(Map<String, String>)
272+
item = documentItemClass.getMethod("of", Map.class).invoke(null, stringMap);
273+
} else {
274+
// Call RerankRequestDocumentsItem.of(String)
275+
item = documentItemClass.getMethod("of", String.class).invoke(null, doc);
276+
}
277+
result.add(item);
278+
}
279+
return result;
280+
}
281+
282+
/**
283+
* Extract reranked results from Cohere API response.
284+
*
285+
* @param originalDocs Original documents
286+
* @param response Cohere API response
287+
* @return RerankResult with reranked documents and scores
288+
* @throws Exception if extraction fails
289+
*/
290+
private RerankResult extractResults(List<?> originalDocs, Object response) throws Exception {
291+
// Get results from response
292+
Class<?> responseClass = response.getClass();
293+
List<?> results = (List<?>) responseClass.getMethod("getResults").invoke(response);
294+
295+
List<Object> rerankedDocs = new ArrayList<>();
296+
List<Double> scores = new ArrayList<>();
297+
298+
// Extract each result
299+
for (Object result : results) {
300+
Class<?> resultClass = result.getClass();
301+
302+
// Get index
303+
int index = (Integer) resultClass.getMethod("getIndex").invoke(result);
304+
305+
// Get relevance score (float -> double)
306+
float scoreFloat = (Float) resultClass.getMethod("getRelevanceScore").invoke(result);
307+
double score = (double) scoreFloat;
308+
309+
// Add to results
310+
rerankedDocs.add(originalDocs.get(index));
311+
scores.add(score);
312+
}
313+
314+
return new RerankResult(rerankedDocs, returnScore ? scores : null);
315+
}
316+
}

0 commit comments

Comments
 (0)