|
| 1 | +package com.redis.vl.utils.rerank; |
| 2 | + |
| 3 | +import com.fasterxml.jackson.databind.JsonNode; |
| 4 | +import com.fasterxml.jackson.databind.ObjectMapper; |
| 5 | +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; |
| 6 | +import java.io.IOException; |
| 7 | +import java.util.*; |
| 8 | +import lombok.Builder; |
| 9 | +import okhttp3.*; |
| 10 | + |
| 11 | +/** |
| 12 | + * Reranker that uses VoyageAI's Rerank API to rerank documents based on query relevance. |
| 13 | + * |
| 14 | + * <p>This reranker interacts with VoyageAI's /v1/rerank API, requiring an API key for |
| 15 | + * authentication. The API key can be provided directly in the {@code apiConfig} Map or through the |
| 16 | + * {@code VOYAGE_API_KEY} environment variable. |
| 17 | + * |
| 18 | + * <p>Users must obtain an API key from <a href="https://dash.voyageai.com/">VoyageAI Dashboard</a>. |
| 19 | + * |
| 20 | + * <p>Example usage: |
| 21 | + * |
| 22 | + * <pre>{@code |
| 23 | + * // Initialize with API key |
| 24 | + * Map<String, String> apiConfig = Map.of("api_key", "your-api-key"); |
| 25 | + * VoyageAIReranker reranker = VoyageAIReranker.builder() |
| 26 | + * .model("rerank-lite-1") |
| 27 | + * .apiConfig(apiConfig) |
| 28 | + * .limit(3) |
| 29 | + * .build(); |
| 30 | + * |
| 31 | + * // Rerank string documents |
| 32 | + * List<String> docs = Arrays.asList("doc1", "doc2", "doc3"); |
| 33 | + * RerankResult result = reranker.rank("query", docs); |
| 34 | + * }</pre> |
| 35 | + * |
| 36 | + * @see <a href="https://docs.voyageai.com/docs/reranker">VoyageAI Rerank API</a> |
| 37 | + */ |
| 38 | +@SuppressFBWarnings( |
| 39 | + value = "EI_EXPOSE_REP2", |
| 40 | + justification = |
| 41 | + "Lombok @Builder generates methods that store mutable objects, " |
| 42 | + + "but defensive copies are made in constructor") |
| 43 | +public class VoyageAIReranker extends BaseReranker { |
| 44 | + |
| 45 | + private static final String API_ENDPOINT = "https://api.voyageai.com/v1/rerank"; |
| 46 | + private static final MediaType JSON = MediaType.get("application/json; charset=utf-8"); |
| 47 | + |
| 48 | + private final Map<String, String> apiConfig; |
| 49 | + private final OkHttpClient httpClient; |
| 50 | + private final ObjectMapper objectMapper; |
| 51 | + |
| 52 | + /** |
| 53 | + * Create a VoyageAIReranker using the builder. |
| 54 | + * |
| 55 | + * @param model The VoyageAI model to use (e.g., "rerank-lite-1", "rerank-2") |
| 56 | + * @param limit Maximum number of results to return (default: 5) |
| 57 | + * @param returnScore Whether to return relevance scores (default: true) |
| 58 | + * @param apiConfig Map containing API configuration (must have "api_key" key) |
| 59 | + */ |
| 60 | + @Builder |
| 61 | + @SuppressFBWarnings( |
| 62 | + value = "EI_EXPOSE_REP2", |
| 63 | + justification = |
| 64 | + "Lombok builder generates methods that store mutable objects, " |
| 65 | + + "but defensive copies are made in constructor") |
| 66 | + private VoyageAIReranker( |
| 67 | + String model, Integer limit, Boolean returnScore, Map<String, String> apiConfig) { |
| 68 | + super( |
| 69 | + model != null ? model : "rerank-lite-1", |
| 70 | + null, // VoyageAI doesn't support rankBy |
| 71 | + limit != null ? limit : 5, |
| 72 | + returnScore != null ? returnScore : true); |
| 73 | + |
| 74 | + // Make defensive copy of apiConfig |
| 75 | + this.apiConfig = |
| 76 | + apiConfig != null ? Collections.unmodifiableMap(new HashMap<>(apiConfig)) : null; |
| 77 | + |
| 78 | + this.httpClient = new OkHttpClient(); |
| 79 | + this.objectMapper = new ObjectMapper(); |
| 80 | + } |
| 81 | + |
| 82 | + /** |
| 83 | + * Rerank documents based on query relevance using VoyageAI's Rerank API. |
| 84 | + * |
| 85 | + * @param query The search query |
| 86 | + * @param docs List of documents (must be List<String>) |
| 87 | + * @return RerankResult with reranked documents and relevance scores |
| 88 | + * @throws IllegalArgumentException if query or docs are invalid |
| 89 | + */ |
| 90 | + @Override |
| 91 | + public RerankResult rank(String query, List<?> docs) { |
| 92 | + return rank(query, docs, Collections.emptyMap()); |
| 93 | + } |
| 94 | + |
| 95 | + /** |
| 96 | + * Rerank documents based on query relevance using VoyageAI's Rerank API with runtime parameter |
| 97 | + * overrides. |
| 98 | + * |
| 99 | + * @param query The search query |
| 100 | + * @param docs List of documents (must be List<String>) |
| 101 | + * @param kwargs Optional parameters to override defaults (limit, return_score, truncation) |
| 102 | + * @return RerankResult with reranked documents and relevance scores |
| 103 | + * @throws IllegalArgumentException if query or docs are invalid |
| 104 | + */ |
| 105 | + @SuppressWarnings("unchecked") |
| 106 | + public RerankResult rank(String query, List<?> docs, Map<String, Object> kwargs) { |
| 107 | + validateQuery(query); |
| 108 | + validateDocs(docs); |
| 109 | + |
| 110 | + if (docs.isEmpty()) { |
| 111 | + return new RerankResult(Collections.emptyList(), Collections.emptyList()); |
| 112 | + } |
| 113 | + |
| 114 | + // Get API key from config or environment |
| 115 | + String apiKey = null; |
| 116 | + if (apiConfig != null && apiConfig.containsKey("api_key")) { |
| 117 | + apiKey = apiConfig.get("api_key"); |
| 118 | + } |
| 119 | + if (apiKey == null || apiKey.isEmpty()) { |
| 120 | + apiKey = System.getenv("VOYAGE_API_KEY"); |
| 121 | + } |
| 122 | + if (apiKey == null || apiKey.isEmpty()) { |
| 123 | + throw new IllegalArgumentException( |
| 124 | + "VoyageAI API key is required. " |
| 125 | + + "Provide it in apiConfig or set the VOYAGE_API_KEY environment variable."); |
| 126 | + } |
| 127 | + |
| 128 | + // Extract runtime parameters with defaults |
| 129 | + int effectiveLimit = (Integer) kwargs.getOrDefault("limit", this.limit); |
| 130 | + boolean effectiveReturnScore = (Boolean) kwargs.getOrDefault("return_score", this.returnScore); |
| 131 | + Object truncation = kwargs.get("truncation"); |
| 132 | + |
| 133 | + // VoyageAI only supports string documents |
| 134 | + List<String> stringDocs; |
| 135 | + if (docs.get(0) instanceof Map) { |
| 136 | + // Extract "content" field from dict docs |
| 137 | + stringDocs = new ArrayList<>(); |
| 138 | + for (Object doc : docs) { |
| 139 | + Map<String, Object> docMap = (Map<String, Object>) doc; |
| 140 | + if (docMap.containsKey("content")) { |
| 141 | + stringDocs.add(String.valueOf(docMap.get("content"))); |
| 142 | + } else { |
| 143 | + throw new IllegalArgumentException( |
| 144 | + "VoyageAI reranker requires documents to be strings or have a 'content' field"); |
| 145 | + } |
| 146 | + } |
| 147 | + } else { |
| 148 | + stringDocs = (List<String>) docs; |
| 149 | + } |
| 150 | + |
| 151 | + try { |
| 152 | + // Build request JSON |
| 153 | + Map<String, Object> requestBody = new HashMap<>(); |
| 154 | + requestBody.put("query", query); |
| 155 | + requestBody.put("documents", stringDocs); |
| 156 | + requestBody.put("model", model); |
| 157 | + requestBody.put("top_k", effectiveLimit); |
| 158 | + if (truncation != null) { |
| 159 | + requestBody.put("truncation", truncation); |
| 160 | + } |
| 161 | + |
| 162 | + String jsonBody = objectMapper.writeValueAsString(requestBody); |
| 163 | + |
| 164 | + // Make HTTP request |
| 165 | + Request request = |
| 166 | + new Request.Builder() |
| 167 | + .url(API_ENDPOINT) |
| 168 | + .post(RequestBody.create(jsonBody, JSON)) |
| 169 | + .addHeader("Authorization", "Bearer " + apiKey) |
| 170 | + .addHeader("Content-Type", "application/json") |
| 171 | + .build(); |
| 172 | + |
| 173 | + try (Response response = httpClient.newCall(request).execute()) { |
| 174 | + if (!response.isSuccessful()) { |
| 175 | + throw new RuntimeException( |
| 176 | + "VoyageAI API request failed: " |
| 177 | + + response.code() |
| 178 | + + " " |
| 179 | + + (response.body() != null ? response.body().string() : "")); |
| 180 | + } |
| 181 | + |
| 182 | + ResponseBody responseBody = response.body(); |
| 183 | + if (responseBody == null) { |
| 184 | + throw new RuntimeException("VoyageAI API returned null response body"); |
| 185 | + } |
| 186 | + |
| 187 | + String responseBodyString = responseBody.string(); |
| 188 | + JsonNode jsonResponse = objectMapper.readTree(responseBodyString); |
| 189 | + JsonNode results = jsonResponse.get("data"); |
| 190 | + |
| 191 | + List<Object> rerankedDocs = new ArrayList<>(); |
| 192 | + List<Double> scores = new ArrayList<>(); |
| 193 | + |
| 194 | + for (JsonNode result : results) { |
| 195 | + int index = result.get("index").asInt(); |
| 196 | + double score = result.get("relevance_score").asDouble(); |
| 197 | + |
| 198 | + rerankedDocs.add(docs.get(index)); |
| 199 | + scores.add(score); |
| 200 | + } |
| 201 | + |
| 202 | + return new RerankResult(rerankedDocs, effectiveReturnScore ? scores : null); |
| 203 | + } |
| 204 | + |
| 205 | + } catch (IOException e) { |
| 206 | + throw new RuntimeException("Failed to call VoyageAI rerank API", e); |
| 207 | + } |
| 208 | + } |
| 209 | +} |
0 commit comments