Skip to content

Commit

Permalink
Cohere: added maxRetries
Browse files Browse the repository at this point in the history
  • Loading branch information
LangChain4j committed Feb 2, 2024
1 parent 2b461a9 commit b86f319
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.time.Duration;
import java.util.List;

import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static java.time.Duration.ofSeconds;
Expand All @@ -25,13 +26,15 @@ public class CohereScoringModel implements ScoringModel {

private final CohereClient client;
private final String modelName;
private final Integer maxRetries;

@Builder
public CohereScoringModel(
String baseUrl,
String apiKey,
String modelName,
Duration timeout,
Integer maxRetries,
Boolean logRequests,
Boolean logResponses
) {
Expand All @@ -43,6 +46,7 @@ public CohereScoringModel(
.logResponses(getOrDefault(logResponses, false))
.build();
this.modelName = modelName;
this.maxRetries = getOrDefault(maxRetries, 3);
}

public static CohereScoringModel withApiKey(String apiKey) {
Expand All @@ -60,7 +64,7 @@ public Response<List<Double>> scoreAll(List<TextSegment> segments, String query)
.collect(toList()))
.build();

RerankResponse response = client.rerank(request);
RerankResponse response = withRetry(() -> client.rerank(request), maxRetries);

List<Double> scores = response.getResults().stream()
.sorted(comparingInt(Result::getIndex))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ void should_score_multiple_segments_with_all_parameters() {
ScoringModel model = CohereScoringModel.builder()
.baseUrl("https://api.cohere.ai/v1/")
.apiKey(System.getenv("COHERE_API_KEY"))
.timeout(Duration.ofSeconds(30))
.modelName("rerank-multilingual-v2.0")
.timeout(Duration.ofSeconds(30))
.maxRetries(2)
.logRequests(true)
.logResponses(true)
.build();
Expand Down

0 comments on commit b86f319

Please sign in to comment.