diff --git a/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/CohereScoringModel.java b/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/CohereScoringModel.java index 7f8fe96195a..29175d9c338 100644 --- a/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/CohereScoringModel.java +++ b/langchain4j-cohere/src/main/java/dev/langchain4j/model/cohere/CohereScoringModel.java @@ -9,6 +9,7 @@ import java.time.Duration; import java.util.List; +import static dev.langchain4j.internal.RetryUtils.withRetry; import static dev.langchain4j.internal.Utils.getOrDefault; import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; import static java.time.Duration.ofSeconds; @@ -25,6 +26,7 @@ public class CohereScoringModel implements ScoringModel { private final CohereClient client; private final String modelName; + private final Integer maxRetries; @Builder public CohereScoringModel( @@ -32,6 +34,7 @@ public CohereScoringModel( String apiKey, String modelName, Duration timeout, + Integer maxRetries, Boolean logRequests, Boolean logResponses ) { @@ -43,6 +46,7 @@ public CohereScoringModel( .logResponses(getOrDefault(logResponses, false)) .build(); this.modelName = modelName; + this.maxRetries = getOrDefault(maxRetries, 3); } public static CohereScoringModel withApiKey(String apiKey) { @@ -60,7 +64,7 @@ public Response> scoreAll(List segments, String query) .collect(toList())) .build(); - RerankResponse response = client.rerank(request); + RerankResponse response = withRetry(() -> client.rerank(request), maxRetries); List scores = response.getResults().stream() .sorted(comparingInt(Result::getIndex)) diff --git a/langchain4j-cohere/src/test/java/dev/langchain4j/model/cohere/CohereScoringModelIT.java b/langchain4j-cohere/src/test/java/dev/langchain4j/model/cohere/CohereScoringModelIT.java index de220478d61..88432aab056 100644 --- a/langchain4j-cohere/src/test/java/dev/langchain4j/model/cohere/CohereScoringModelIT.java +++ b/langchain4j-cohere/src/test/java/dev/langchain4j/model/cohere/CohereScoringModelIT.java @@ -41,8 +41,9 @@ void should_score_multiple_segments_with_all_parameters() { ScoringModel model = CohereScoringModel.builder() .baseUrl("https://api.cohere.ai/v1/") .apiKey(System.getenv("COHERE_API_KEY")) - .timeout(Duration.ofSeconds(30)) .modelName("rerank-multilingual-v2.0") + .timeout(Duration.ofSeconds(30)) + .maxRetries(2) .logRequests(true) .logResponses(true) .build();