Skip to content

Commit ff51f4d

Browse files
committed
Add custom retry transport
Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent 38453db commit ff51f4d

File tree

1 file changed

+140
-2
lines changed

1 file changed

+140
-2
lines changed

replicate/client.py

Lines changed: 140 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
11
import os
2+
import random
23
import re
3-
from typing import Any, Iterator, Optional, Union
4+
import time
5+
from datetime import datetime
6+
from typing import (
7+
Any,
8+
Iterable,
9+
Iterator,
10+
Mapping,
11+
Optional,
12+
Union,
13+
)
414

515
import httpx
616

@@ -49,7 +59,7 @@ def __init__(
4959
base_url=base_url,
5060
headers=headers,
5161
timeout=timeout,
52-
transport=transport,
62+
transport=RetryTransport(wrapped_transport=transport),
5363
)
5464

5565
def _build_client(self, **kwargs) -> httpx.Client:
@@ -111,3 +121,131 @@ def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]:
111121
if prediction.status == "failed":
112122
raise ModelError(prediction.error)
113123
return prediction.output
124+
125+
126+
# Adapted from https://github.com/encode/httpx/issues/108#issuecomment-1132753155
127+
class RetryTransport(httpx.AsyncBaseTransport, httpx.BaseTransport):
128+
"""A custom HTTP transport that automatically retries requests using an exponential backoff strategy
129+
for specific HTTP status codes and request methods.
130+
"""
131+
132+
RETRYABLE_METHODS = frozenset(["HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE"])
133+
RETRYABLE_STATUS_CODES = frozenset(
134+
[
135+
429, # Too Many Requests
136+
503, # Service Unavailable
137+
504, # Gateway Timeout
138+
]
139+
)
140+
MAX_BACKOFF_WAIT = 60
141+
142+
def __init__(
143+
self,
144+
wrapped_transport: Union[httpx.BaseTransport, httpx.AsyncBaseTransport],
145+
max_attempts: int = 10,
146+
max_backoff_wait: float = MAX_BACKOFF_WAIT,
147+
backoff_factor: float = 0.1,
148+
jitter_ratio: float = 0.1,
149+
retryable_methods: Optional[Iterable[str]] = None,
150+
retry_status_codes: Optional[Iterable[int]] = None,
151+
) -> None:
152+
self._wrapped_transport = wrapped_transport
153+
154+
if jitter_ratio < 0 or jitter_ratio > 0.5:
155+
raise ValueError(
156+
f"jitter ratio should be between 0 and 0.5, actual {jitter_ratio}"
157+
)
158+
159+
self.max_attempts = max_attempts
160+
self.backoff_factor = backoff_factor
161+
self.retryable_methods = (
162+
frozenset(retryable_methods)
163+
if retryable_methods
164+
else self.RETRYABLE_METHODS
165+
)
166+
self.retry_status_codes = (
167+
frozenset(retry_status_codes)
168+
if retry_status_codes
169+
else self.RETRYABLE_STATUS_CODES
170+
)
171+
self.jitter_ratio = jitter_ratio
172+
self.max_backoff_wait = max_backoff_wait
173+
174+
def _calculate_sleep(
175+
self, attempts_made: int, headers: Union[httpx.Headers, Mapping[str, str]]
176+
) -> float:
177+
retry_after_header = (headers.get("Retry-After") or "").strip()
178+
if retry_after_header:
179+
if retry_after_header.isdigit():
180+
return float(retry_after_header)
181+
182+
try:
183+
parsed_date = datetime.fromisoformat(retry_after_header).astimezone()
184+
diff = (parsed_date - datetime.now().astimezone()).total_seconds()
185+
if diff > 0:
186+
return min(diff, self.max_backoff_wait)
187+
except ValueError:
188+
pass
189+
190+
backoff = self.backoff_factor * (2 ** (attempts_made - 1))
191+
jitter = (backoff * self.jitter_ratio) * random.choice([1, -1])
192+
total_backoff = backoff + jitter
193+
return min(total_backoff, self.max_backoff_wait)
194+
195+
def handle_request(self, request: httpx.Request) -> httpx.Response:
196+
response = self._wrapped_transport.handle_request(request) # type: ignore
197+
198+
if request.method not in self.retryable_methods:
199+
return response
200+
201+
remaining_attempts = self.max_attempts - 1
202+
attempts_made = 1
203+
204+
while True:
205+
if (
206+
remaining_attempts < 1
207+
or response.status_code not in self.retry_status_codes
208+
):
209+
return response
210+
211+
response.close()
212+
213+
sleep_for = self._calculate_sleep(attempts_made, response.headers)
214+
time.sleep(sleep_for)
215+
216+
response = self._wrapped_transport.handle_request(request) # type: ignore
217+
218+
attempts_made += 1
219+
remaining_attempts -= 1
220+
221+
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
222+
response = await self._wrapped_transport.handle_async_request(request) # type: ignore
223+
224+
if request.method not in self.retryable_methods:
225+
return response
226+
227+
remaining_attempts = self.max_attempts - 1
228+
attempts_made = 1
229+
230+
while True:
231+
if (
232+
remaining_attempts < 1
233+
or response.status_code not in self.retry_status_codes
234+
):
235+
return response
236+
237+
response.close()
238+
239+
sleep_for = self._calculate_sleep(attempts_made, response.headers)
240+
time.sleep(sleep_for)
241+
242+
response = await self._wrapped_transport.handle_async_request(request) # type: ignore
243+
244+
attempts_made += 1
245+
remaining_attempts -= 1
246+
247+
async def aclose(self) -> None:
248+
await self._wrapped_transport.aclose() # type: ignore
249+
250+
def close(self) -> None:
251+
self._wrapped_transport.close() # type: ignore

0 commit comments

Comments
 (0)