Skip to content

Commit 12829c6

Browse files
committed
Add custom retry transport
Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent e0a3db6 commit 12829c6

File tree

1 file changed

+140
-1
lines changed

1 file changed

+140
-1
lines changed

replicate/client.py

Lines changed: 140 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
11
import os
2+
import random
23
import re
3-
from typing import Any, Iterator, Optional, Union
4+
import time
5+
from datetime import datetime
6+
from typing import (
7+
Any,
8+
Iterable,
9+
Iterator,
10+
Mapping,
11+
Optional,
12+
Union,
13+
)
414

515
import httpx
616

@@ -46,6 +56,7 @@ def __init__(
4656
base_url=base_url,
4757
headers=headers,
4858
timeout=timeout,
59+
transport=RetryTransport(wrapped_transport=httpx.HTTPTransport()),
4960
)
5061

5162
def _build_client(self, **kwargs) -> httpx.Client:
@@ -103,3 +114,131 @@ def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]:
103114
if prediction.status == "failed":
104115
raise ModelError(prediction.error)
105116
return prediction.output
117+
118+
119+
# Adapted from https://github.com/encode/httpx/issues/108#issuecomment-1132753155
120+
class RetryTransport(httpx.AsyncBaseTransport, httpx.BaseTransport):
    """A custom HTTP transport that automatically retries requests using an
    exponential backoff strategy for specific HTTP status codes and request
    methods.

    The transport wraps another transport and delegates every request to it.
    When the request method is in ``retryable_methods`` and the response
    status is in ``retry_status_codes``, the response is closed, the
    transport waits, and the request is sent again, up to ``max_attempts``
    total attempts. The last response is returned as-is, whatever its status.

    The wait honours a ``Retry-After`` response header (delay in seconds,
    HTTP-date, or ISO 8601 timestamp) and otherwise falls back to
    ``backoff_factor * 2 ** (attempt - 1)`` with +/- ``jitter_ratio`` jitter.
    Every wait is capped at ``max_backoff_wait`` seconds.
    """

    RETRYABLE_METHODS = frozenset(["HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE"])
    RETRYABLE_STATUS_CODES = frozenset(
        [
            429,  # Too Many Requests
            503,  # Service Unavailable
            504,  # Gateway Timeout
        ]
    )
    MAX_BACKOFF_WAIT = 60

    def __init__(
        self,
        wrapped_transport: Union[httpx.BaseTransport, httpx.AsyncBaseTransport],
        max_attempts: int = 10,
        max_backoff_wait: float = MAX_BACKOFF_WAIT,
        backoff_factor: float = 0.1,
        jitter_ratio: float = 0.1,
        retryable_methods: Optional[Iterable[str]] = None,
        retry_status_codes: Optional[Iterable[int]] = None,
    ) -> None:
        """Create a retrying wrapper around ``wrapped_transport``.

        Args:
            wrapped_transport: Transport that actually sends the requests.
            max_attempts: Total number of attempts, including the first one.
            max_backoff_wait: Upper bound, in seconds, for a single wait.
            backoff_factor: Base delay, in seconds, of the exponential backoff.
            jitter_ratio: Fraction of the backoff added or subtracted at
                random; must be within ``[0, 0.5]``.
            retryable_methods: HTTP methods to retry. ``None`` or an empty
                iterable selects ``RETRYABLE_METHODS``.
            retry_status_codes: Status codes to retry. ``None`` or an empty
                iterable selects ``RETRYABLE_STATUS_CODES``.

        Raises:
            ValueError: If ``jitter_ratio`` is outside ``[0, 0.5]``.
        """
        self._wrapped_transport = wrapped_transport

        if jitter_ratio < 0 or jitter_ratio > 0.5:
            raise ValueError(
                f"jitter ratio should be between 0 and 0.5, actual {jitter_ratio}"
            )

        self.max_attempts = max_attempts
        self.backoff_factor = backoff_factor
        self.retryable_methods = (
            frozenset(retryable_methods)
            if retryable_methods
            else self.RETRYABLE_METHODS
        )
        self.retry_status_codes = (
            frozenset(retry_status_codes)
            if retry_status_codes
            else self.RETRYABLE_STATUS_CODES
        )
        self.jitter_ratio = jitter_ratio
        self.max_backoff_wait = max_backoff_wait

    @staticmethod
    def _parse_retry_after_date(value: str) -> Optional[datetime]:
        """Parse a date-valued ``Retry-After`` header into an aware datetime.

        ISO 8601 is tried first (naive values are taken as local time), then
        the HTTP-date format, e.g. ``Wed, 21 Oct 2015 07:28:00 GMT``.
        Returns ``None`` when the value matches neither format.
        """
        # Function-scope imports keep this class self-contained with respect
        # to the module's top-level import block.
        from datetime import timezone
        from email.utils import parsedate_to_datetime

        try:
            return datetime.fromisoformat(value).astimezone()
        except ValueError:
            pass

        try:
            parsed = parsedate_to_datetime(value)
        except (TypeError, ValueError):
            # Unparsable input raises TypeError before Python 3.10 and
            # ValueError from 3.10 onwards.
            return None

        if parsed.tzinfo is None:
            # A "-0000" zone yields a naive datetime that represents UTC.
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    def _calculate_sleep(
        self, attempts_made: int, headers: Union[httpx.Headers, Mapping[str, str]]
    ) -> float:
        """Return how many seconds to wait before the next attempt.

        Args:
            attempts_made: Number of attempts already performed (>= 1).
            headers: Headers of the response that triggered the retry.

        Returns:
            The ``Retry-After`` delay when the header is present and usable,
            otherwise the jittered exponential backoff; in both cases never
            more than ``max_backoff_wait``.
        """
        retry_after_header = (headers.get("Retry-After") or "").strip()
        if retry_after_header:
            # isdecimal() only accepts characters float() can convert, unlike
            # isdigit(), which also accepts e.g. superscript digits.
            if retry_after_header.isdecimal():
                # Cap the server-supplied delay like every other wait so a
                # huge value cannot stall the client indefinitely.
                return min(float(retry_after_header), self.max_backoff_wait)

            retry_at = self._parse_retry_after_date(retry_after_header)
            if retry_at is not None:
                diff = (retry_at - datetime.now().astimezone()).total_seconds()
                # A date in the past carries no usable delay; fall through
                # to the regular backoff below.
                if diff > 0:
                    return min(diff, self.max_backoff_wait)

        backoff = self.backoff_factor * (2 ** (attempts_made - 1))
        jitter = (backoff * self.jitter_ratio) * random.choice([1, -1])
        total_backoff = backoff + jitter
        return min(total_backoff, self.max_backoff_wait)

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """Send ``request`` synchronously, retrying retryable failures.

        Returns the first response whose status is not retryable, or the
        response of the final attempt once ``max_attempts`` is reached.
        """
        response = self._wrapped_transport.handle_request(request)  # type: ignore

        if request.method not in self.retryable_methods:
            return response

        attempts_made = 1
        while (
            attempts_made < self.max_attempts
            and response.status_code in self.retry_status_codes
        ):
            # Release the connection of the discarded response before waiting.
            response.close()

            time.sleep(self._calculate_sleep(attempts_made, response.headers))

            response = self._wrapped_transport.handle_request(request)  # type: ignore
            attempts_made += 1

        return response

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        """Send ``request`` asynchronously, retrying retryable failures.

        Same policy as ``handle_request``. Waiting uses ``asyncio.sleep`` so
        the event loop is not blocked between attempts (this assumes the
        client runs on an asyncio event loop).
        """
        import asyncio

        response = await self._wrapped_transport.handle_async_request(request)  # type: ignore

        if request.method not in self.retryable_methods:
            return response

        attempts_made = 1
        while (
            attempts_made < self.max_attempts
            and response.status_code in self.retry_status_codes
        ):
            # Async responses must be closed with aclose(); the synchronous
            # close() is not valid for an async byte stream.
            await response.aclose()

            await asyncio.sleep(self._calculate_sleep(attempts_made, response.headers))

            response = await self._wrapped_transport.handle_async_request(request)  # type: ignore
            attempts_made += 1

        return response

    async def aclose(self) -> None:
        """Close the wrapped asynchronous transport."""
        await self._wrapped_transport.aclose()  # type: ignore

    def close(self) -> None:
        """Close the wrapped synchronous transport."""
        self._wrapped_transport.close()  # type: ignore

0 commit comments

Comments
 (0)