Skip to content

Commit

Permalink
improved retry mechanism (langchain4j#32)
Browse files Browse the repository at this point in the history
- do not retry in case of 401
- wait for 1 second before retrying in case of 429
  • Loading branch information
langchain4j authored Jul 17, 2023
1 parent b202826 commit ca96e99
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 3 deletions.
2 changes: 1 addition & 1 deletion langchain4j-parent/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
<dependency>
<groupId>dev.ai4j</groupId>
<artifactId>openai4j</artifactId>
<version>0.5.1</version>
<version>0.5.2</version>
</dependency>

<dependency>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dev.langchain4j.internal;

import dev.ai4j.openai4j.OpenAiHttpException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -9,11 +10,16 @@

public class RetryUtils {

private static final int HTTP_CODE_401_UNAUTHORIZED = 401;
private static final int HTTP_CODE_429_TOO_MANY_REQUESTS = 429;

private static final Logger log = LoggerFactory.getLogger(RetryUtils.class);

/**
* This method attempts to execute a given action up to a specified number of times.
* If the action fails on all attempts, it throws a RuntimeException.
* Retry will not happen for 401 (Unauthorized).
* Retry will happen after 1-second delay for 429 (Too many requests).
*
* @param action The action to be executed.
* @param maxAttempts The maximum number of attempts to execute the action.
Expand All @@ -24,11 +30,29 @@ public static <T> T withRetry(Callable<T> action, int maxAttempts) {
for (int attempt = 1; attempt <= maxAttempts; attempt++) {
try {
return action.call();
} catch (OpenAiHttpException e) {
if (attempt == maxAttempts) {
throw new RuntimeException(e);
}

if (e.code() == HTTP_CODE_401_UNAUTHORIZED) {
throw new RuntimeException(e); // makes no sense to retry
}

log.warn(format("Exception was thrown on attempt %s of %s", attempt, maxAttempts), e);

if (e.code() == HTTP_CODE_429_TOO_MANY_REQUESTS) {
try {
// TODO make configurable or read from Retry-After
Thread.sleep(1000); // makes sense to retry after a bit of waiting
} catch (InterruptedException ignored) {
}
}
} catch (Exception e) {
if (attempt == maxAttempts) {
throw new RuntimeException(e);
}
log.warn(format("Exception was thrown on attempt %s of %s", action, maxAttempts), e);
log.warn(format("Exception was thrown on attempt %s of %s", attempt, maxAttempts), e);
}
}
throw new RuntimeException("Failed after " + maxAttempts + " attempts");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dev.langchain4j.internal;

import dev.ai4j.openai4j.OpenAiHttpException;
import org.junit.jupiter.api.Test;

import java.util.concurrent.Callable;
Expand All @@ -20,17 +21,21 @@ void testSuccessfulCall() throws Exception {

assertThat(result).isEqualTo("Success");
verify(mockAction).call();
verifyNoMoreInteractions(mockAction);
}

@Test
void testRetryThenSuccess() throws Exception {
Callable<String> mockAction = mock(Callable.class);
when(mockAction.call()).thenThrow(new RuntimeException()).thenReturn("Success");
when(mockAction.call())
.thenThrow(new RuntimeException())
.thenReturn("Success");

String result = withRetry(mockAction, 3);

assertThat(result).isEqualTo("Success");
verify(mockAction, times(2)).call();
verifyNoMoreInteractions(mockAction);
}

@Test
Expand All @@ -41,5 +46,41 @@ void testMaxAttemptsReached() throws Exception {
assertThatThrownBy(() -> withRetry(mockAction, 3))
.isInstanceOf(RuntimeException.class);
verify(mockAction, times(3)).call();
verifyNoMoreInteractions(mockAction);
}

@Test
void should_not_retry_401_unauthorized() throws Exception {
Callable<String> mockAction = mock(Callable.class);
when(mockAction.call()).thenThrow(new OpenAiHttpException(401, "Unauthorized"));

assertThatThrownBy(() -> withRetry(mockAction, 3))
.isInstanceOf(RuntimeException.class)
.hasCauseInstanceOf(OpenAiHttpException.class)
.hasRootCauseMessage("Unauthorized");

verify(mockAction).call();
verifyNoMoreInteractions(mockAction);
}

@Test
void should_wait_1_second_before_retry_when_429_too_many_requests() throws Exception {
Callable<String> mockAction = mock(Callable.class);
when(mockAction.call())
.thenThrow(new OpenAiHttpException(429, "Too many requests"))
.thenReturn("Success");

long startTime = System.currentTimeMillis();

String result = withRetry(mockAction, 3);

long endTime = System.currentTimeMillis();
long duration = endTime - startTime;

assertThat(result).isEqualTo("Success");
verify(mockAction, times(2)).call();
verifyNoMoreInteractions(mockAction);

assertThat(duration).isGreaterThanOrEqualTo(1000);
}
}

0 comments on commit ca96e99

Please sign in to comment.