|
| 1 | +package com.redis.vl.utils.rerank; |
| 2 | + |
| 3 | +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; |
| 4 | +import java.util.*; |
| 5 | +import java.util.stream.Collectors; |
| 6 | +import lombok.Builder; |
| 7 | + |
| 8 | +/** |
| 9 | + * Reranker that uses Cohere's Rerank API to rerank documents based on query relevance. |
| 10 | + * |
| 11 | + * <p>This reranker interacts with Cohere's /rerank API, requiring an API key for authentication. |
| 12 | + * The API key can be provided directly in the {@code apiConfig} Map or through the {@code |
| 13 | + * COHERE_API_KEY} environment variable. |
| 14 | + * |
| 15 | + * <p>Users must obtain an API key from <a |
| 16 | + * href="https://dashboard.cohere.com/">https://dashboard.cohere.com/</a>. Additionally, the {@code |
| 17 | + * com.cohere:cohere-java} library must be available on the classpath. |
| 18 | + * |
| 19 | + * <p>Example usage: |
| 20 | + * |
| 21 | + * <pre>{@code |
| 22 | + * // Initialize with API key |
| 23 | + * Map<String, String> apiConfig = Map.of("api_key", "your-api-key"); |
| 24 | + * CohereReranker reranker = CohereReranker.builder() |
| 25 | + * .apiConfig(apiConfig) |
| 26 | + * .limit(3) |
| 27 | + * .build(); |
| 28 | + * |
| 29 | + * // Rerank string documents |
| 30 | + * List<String> docs = Arrays.asList("doc1", "doc2", "doc3"); |
| 31 | + * RerankResult result = reranker.rank("query", docs); |
| 32 | + * |
| 33 | + * // Rerank dict documents with rankBy fields |
| 34 | + * List<Map<String, Object>> dictDocs = Arrays.asList( |
| 35 | + * Map.of("content", "doc1", "source", "wiki"), |
| 36 | + * Map.of("content", "doc2", "source", "textbook") |
| 37 | + * ); |
| 38 | + * CohereReranker reranker2 = CohereReranker.builder() |
| 39 | + * .apiConfig(apiConfig) |
| 40 | + * .rankBy(Arrays.asList("content", "source")) |
| 41 | + * .build(); |
| 42 | + * RerankResult result2 = reranker2.rank("query", dictDocs); |
| 43 | + * }</pre> |
| 44 | + * |
| 45 | + * @see <a href="https://docs.cohere.com/reference/rerank">Cohere Rerank API</a> |
| 46 | + */ |
| 47 | +@SuppressFBWarnings( |
| 48 | + value = "EI_EXPOSE_REP2", |
| 49 | + justification = |
| 50 | + "Lombok @Builder generates methods that store mutable objects, " |
| 51 | + + "but defensive copies are made in constructor") |
| 52 | +public class CohereReranker extends BaseReranker { |
| 53 | + |
| 54 | + private final Map<String, String> apiConfig; |
| 55 | + private volatile Object client; // com.cohere.api.Cohere (lazy loaded) |
| 56 | + |
| 57 | + /** |
| 58 | + * Create a CohereReranker using the builder. |
| 59 | + * |
| 60 | + * @param model The Cohere model to use (default: "rerank-english-v3.0") |
| 61 | + * @param rankBy List of fields to rank by for dict documents (optional) |
| 62 | + * @param limit Maximum number of results to return (default: 5) |
| 63 | + * @param returnScore Whether to return relevance scores (default: true) |
| 64 | + * @param apiConfig Map containing API configuration (must have "api_key" key) |
| 65 | + */ |
| 66 | + @Builder |
| 67 | + @SuppressFBWarnings( |
| 68 | + value = "EI_EXPOSE_REP2", |
| 69 | + justification = |
| 70 | + "Lombok builder generates methods that store mutable objects, " |
| 71 | + + "but defensive copies are made in constructor") |
| 72 | + private CohereReranker( |
| 73 | + String model, |
| 74 | + List<String> rankBy, |
| 75 | + Integer limit, |
| 76 | + Boolean returnScore, |
| 77 | + Map<String, String> apiConfig) { |
| 78 | + super( |
| 79 | + model != null ? model : "rerank-english-v3.0", |
| 80 | + rankBy, |
| 81 | + limit != null ? limit : 5, |
| 82 | + returnScore != null ? returnScore : true); |
| 83 | + |
| 84 | + // Make defensive copy of apiConfig to avoid EI2 SpotBugs warning |
| 85 | + this.apiConfig = |
| 86 | + apiConfig != null ? Collections.unmodifiableMap(new HashMap<>(apiConfig)) : null; |
| 87 | + } |
| 88 | + |
| 89 | + /** Initialize the Cohere client using the API key from apiConfig or environment. */ |
| 90 | + private synchronized void initializeClient() { |
| 91 | + if (client != null) { |
| 92 | + return; |
| 93 | + } |
| 94 | + |
| 95 | + // Check for Cohere SDK availability |
| 96 | + try { |
| 97 | + Class.forName("com.cohere.api.Cohere"); |
| 98 | + } catch (ClassNotFoundException e) { |
| 99 | + throw new IllegalStateException( |
| 100 | + "Cohere reranker requires the cohere-java library. " |
| 101 | + + "Please add dependency: com.cohere:cohere-java:1.8.1", |
| 102 | + e); |
| 103 | + } |
| 104 | + |
| 105 | + // Get API key from config or environment |
| 106 | + String apiKey = null; |
| 107 | + if (apiConfig != null && apiConfig.containsKey("api_key")) { |
| 108 | + apiKey = apiConfig.get("api_key"); |
| 109 | + } |
| 110 | + if (apiKey == null || apiKey.isEmpty()) { |
| 111 | + apiKey = System.getenv("COHERE_API_KEY"); |
| 112 | + } |
| 113 | + if (apiKey == null || apiKey.isEmpty()) { |
| 114 | + throw new IllegalArgumentException( |
| 115 | + "Cohere API key is required. " |
| 116 | + + "Provide it in apiConfig or set the COHERE_API_KEY environment variable."); |
| 117 | + } |
| 118 | + |
| 119 | + // Create Cohere client |
| 120 | + this.client = createCohereClient(apiKey); |
| 121 | + } |
| 122 | + |
| 123 | + /** |
| 124 | + * Create a Cohere client instance. |
| 125 | + * |
| 126 | + * @param apiKey The API key |
| 127 | + * @return Cohere client |
| 128 | + */ |
| 129 | + private Object createCohereClient(String apiKey) { |
| 130 | + try { |
| 131 | + // Import Cohere class |
| 132 | + Class<?> cohereClass = Class.forName("com.cohere.api.Cohere"); |
| 133 | + |
| 134 | + // Call Cohere.builder() |
| 135 | + Object builder = cohereClass.getMethod("builder").invoke(null); |
| 136 | + |
| 137 | + // Get builder class |
| 138 | + Class<?> builderClass = builder.getClass(); |
| 139 | + |
| 140 | + // Call .token(apiKey) |
| 141 | + builderClass.getMethod("token", String.class).invoke(builder, apiKey); |
| 142 | + |
| 143 | + // Call .clientName("redisvl4j") |
| 144 | + builderClass.getMethod("clientName", String.class).invoke(builder, "redisvl4j"); |
| 145 | + |
| 146 | + // Call .build() |
| 147 | + return builderClass.getMethod("build").invoke(builder); |
| 148 | + } catch (Exception e) { |
| 149 | + throw new IllegalStateException("Failed to create Cohere client", e); |
| 150 | + } |
| 151 | + } |
| 152 | + |
| 153 | + /** |
| 154 | + * Rerank documents based on query relevance using Cohere's Rerank API. |
| 155 | + * |
| 156 | + * @param query The search query |
| 157 | + * @param docs List of documents (either List<String> or List<Map<String, |
| 158 | + * Object>>) |
| 159 | + * @return RerankResult with reranked documents and relevance scores |
| 160 | + * @throws IllegalArgumentException if query or docs are invalid |
| 161 | + */ |
| 162 | + @Override |
| 163 | + public RerankResult rank(String query, List<?> docs) { |
| 164 | + validateQuery(query); |
| 165 | + validateDocs(docs); |
| 166 | + |
| 167 | + if (docs.isEmpty()) { |
| 168 | + return new RerankResult(Collections.emptyList(), Collections.emptyList()); |
| 169 | + } |
| 170 | + |
| 171 | + // Lazy initialize client |
| 172 | + if (client == null) { |
| 173 | + initializeClient(); |
| 174 | + } |
| 175 | + |
| 176 | + // Determine if we're working with strings or dicts |
| 177 | + boolean isDictDocs = !docs.isEmpty() && docs.get(0) instanceof Map; |
| 178 | + |
| 179 | + if (isDictDocs && (rankBy == null || rankBy.isEmpty())) { |
| 180 | + throw new IllegalArgumentException( |
| 181 | + "If reranking dictionary-like docs, you must provide a list of rankBy fields"); |
| 182 | + } |
| 183 | + |
| 184 | + try { |
| 185 | + // Call Cohere rerank API |
| 186 | + Object response = callRerankApi(query, docs, isDictDocs); |
| 187 | + |
| 188 | + // Extract results |
| 189 | + return extractResults(docs, response); |
| 190 | + |
| 191 | + } catch (Exception e) { |
| 192 | + throw new RuntimeException("Failed to call Cohere rerank API", e); |
| 193 | + } |
| 194 | + } |
| 195 | + |
| 196 | + /** |
| 197 | + * Call the Cohere rerank API. |
| 198 | + * |
| 199 | + * @param query The search query |
| 200 | + * @param docs Documents to rerank |
| 201 | + * @param isDictDocs Whether documents are Maps or Strings |
| 202 | + * @return Response from Cohere API |
| 203 | + * @throws Exception if API call fails |
| 204 | + */ |
| 205 | + private Object callRerankApi(String query, List<?> docs, boolean isDictDocs) throws Exception { |
| 206 | + // Get RerankRequest.builder() |
| 207 | + Class<?> rerankRequestClass = Class.forName("com.cohere.api.requests.RerankRequest"); |
| 208 | + Object requestBuilder = rerankRequestClass.getMethod("builder").invoke(null); |
| 209 | + |
| 210 | + // Get builder class (it's a staged builder, so we need to follow the stages) |
| 211 | + Class<?> currentStageClass = requestBuilder.getClass(); |
| 212 | + |
| 213 | + // Set query (QueryStage -> DocumentsStage) |
| 214 | + Object documentsStage = |
| 215 | + currentStageClass.getMethod("query", String.class).invoke(requestBuilder, query); |
| 216 | + |
| 217 | + // Convert docs to RerankRequestDocumentsItem list |
| 218 | + List<Object> documentItems = convertToDocumentItems(docs, isDictDocs); |
| 219 | + |
| 220 | + // Set documents (DocumentsStage -> _FinalStage) |
| 221 | + Class<?> documentsStageClass = documentsStage.getClass(); |
| 222 | + Object finalStage = |
| 223 | + documentsStageClass |
| 224 | + .getMethod("documents", List.class) |
| 225 | + .invoke(documentsStage, documentItems); |
| 226 | + |
| 227 | + // On the final stage, set optional parameters |
| 228 | + Class<?> finalStageClass = finalStage.getClass(); |
| 229 | + |
| 230 | + // Set model if not default |
| 231 | + finalStage = finalStageClass.getMethod("model", String.class).invoke(finalStage, model); |
| 232 | + |
| 233 | + // Set topN (limit) |
| 234 | + finalStage = finalStageClass.getMethod("topN", Integer.class).invoke(finalStage, limit); |
| 235 | + |
| 236 | + // Set rankFields for dict documents |
| 237 | + if (isDictDocs && rankBy != null && !rankBy.isEmpty()) { |
| 238 | + finalStage = finalStageClass.getMethod("rankFields", List.class).invoke(finalStage, rankBy); |
| 239 | + } |
| 240 | + |
| 241 | + // Build the request |
| 242 | + Object request = finalStageClass.getMethod("build").invoke(finalStage); |
| 243 | + |
| 244 | + // Call client.rerank(request) |
| 245 | + Class<?> cohereClass = client.getClass(); |
| 246 | + return cohereClass.getMethod("rerank", rerankRequestClass).invoke(client, request); |
| 247 | + } |
| 248 | + |
| 249 | + /** |
| 250 | + * Convert documents to RerankRequestDocumentsItem list. |
| 251 | + * |
| 252 | + * @param docs Documents |
| 253 | + * @param isDictDocs Whether documents are Maps or Strings |
| 254 | + * @return List of RerankRequestDocumentsItem |
| 255 | + * @throws Exception if conversion fails |
| 256 | + */ |
| 257 | + private List<Object> convertToDocumentItems(List<?> docs, boolean isDictDocs) throws Exception { |
| 258 | + Class<?> documentItemClass = Class.forName("com.cohere.api.types.RerankRequestDocumentsItem"); |
| 259 | + |
| 260 | + List<Object> result = new ArrayList<>(); |
| 261 | + for (Object doc : docs) { |
| 262 | + Object item; |
| 263 | + if (isDictDocs) { |
| 264 | + // Convert Map<String, Object> to Map<String, String> for Cohere API |
| 265 | + @SuppressWarnings("unchecked") |
| 266 | + Map<String, Object> docMap = (Map<String, Object>) doc; |
| 267 | + Map<String, String> stringMap = |
| 268 | + docMap.entrySet().stream() |
| 269 | + .collect(Collectors.toMap(Map.Entry::getKey, e -> String.valueOf(e.getValue()))); |
| 270 | + |
| 271 | + // Call RerankRequestDocumentsItem.of(Map<String, String>) |
| 272 | + item = documentItemClass.getMethod("of", Map.class).invoke(null, stringMap); |
| 273 | + } else { |
| 274 | + // Call RerankRequestDocumentsItem.of(String) |
| 275 | + item = documentItemClass.getMethod("of", String.class).invoke(null, doc); |
| 276 | + } |
| 277 | + result.add(item); |
| 278 | + } |
| 279 | + return result; |
| 280 | + } |
| 281 | + |
| 282 | + /** |
| 283 | + * Extract reranked results from Cohere API response. |
| 284 | + * |
| 285 | + * @param originalDocs Original documents |
| 286 | + * @param response Cohere API response |
| 287 | + * @return RerankResult with reranked documents and scores |
| 288 | + * @throws Exception if extraction fails |
| 289 | + */ |
| 290 | + private RerankResult extractResults(List<?> originalDocs, Object response) throws Exception { |
| 291 | + // Get results from response |
| 292 | + Class<?> responseClass = response.getClass(); |
| 293 | + List<?> results = (List<?>) responseClass.getMethod("getResults").invoke(response); |
| 294 | + |
| 295 | + List<Object> rerankedDocs = new ArrayList<>(); |
| 296 | + List<Double> scores = new ArrayList<>(); |
| 297 | + |
| 298 | + // Extract each result |
| 299 | + for (Object result : results) { |
| 300 | + Class<?> resultClass = result.getClass(); |
| 301 | + |
| 302 | + // Get index |
| 303 | + int index = (Integer) resultClass.getMethod("getIndex").invoke(result); |
| 304 | + |
| 305 | + // Get relevance score (float -> double) |
| 306 | + float scoreFloat = (Float) resultClass.getMethod("getRelevanceScore").invoke(result); |
| 307 | + double score = (double) scoreFloat; |
| 308 | + |
| 309 | + // Add to results |
| 310 | + rerankedDocs.add(originalDocs.get(index)); |
| 311 | + scores.add(score); |
| 312 | + } |
| 313 | + |
| 314 | + return new RerankResult(rerankedDocs, returnScore ? scores : null); |
| 315 | + } |
| 316 | +} |
0 commit comments