Skip to content

Commit

Permalink
Add py 429 retries (#467)
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw authored Feb 23, 2024
1 parent 5d95f8a commit fc65996
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 2 deletions.
31 changes: 30 additions & 1 deletion python/langsmith/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,7 @@ def request_with_retries(
stop_after_attempt: int = 1,
retry_on: Optional[Sequence[Type[BaseException]]] = None,
to_ignore: Optional[Sequence[Type[BaseException]]] = None,
handle_response: Optional[Callable[[requests.Response, int], Any]] = None,
) -> requests.Response:
"""Send a request with retries.
Expand All @@ -540,6 +541,9 @@ def request_with_retries(
[LangSmithConnectionError, LangSmithAPIError].
to_ignore : Sequence[Type[BaseException]] or None, default=None
The exceptions to ignore / pass on.
handle_response : Callable[[requests.Response, int], Any] or None, default=None
A function to handle the response and return whether to continue
retrying.
Returns:
-------
Expand Down Expand Up @@ -567,6 +571,7 @@ def request_with_retries(
)
to_ignore_: Tuple[Type[BaseException], ...] = (*(to_ignore or ()),)
response = None

for idx in range(stop_after_attempt):
try:
try:
Expand All @@ -578,6 +583,11 @@ def request_with_retries(
return response
except requests.HTTPError as e:
if response is not None:
if handle_response is not None:
if idx + 1 < stop_after_attempt:
should_continue = handle_response(response, idx + 1)
if should_continue:
continue
if response.status_code == 500:
raise ls_utils.LangSmithAPIError(
f"Server error caused failure to {request_method}"
Expand Down Expand Up @@ -1085,7 +1095,24 @@ def batch_ingest_runs(
return

self._insert_runtime_env(body["post"])
logger.debug(f"Batch ingesting {len(body['post'])}, {len(body['patch'])} runs")

def handle_429(response: requests.Response, attempt: int) -> bool:
    """Back off and ask the caller to retry when the server replies 429.

    Args:
        response: The HTTP response that failed the status check.
        attempt: 1-based number of the attempt that just failed; used as
            the exponent of the backoff multiplier.

    Returns:
        True after sleeping when ``response.status_code`` is 429 (the
        caller should retry); False for any other status code.
    """
    if response.status_code != 429:
        return False

    # Used when the header is absent or unusable. With the exponential
    # multiplier below this gives roughly 30s, then 60s.
    default_delay = 30.0
    raw_value = response.headers.get("retry-after", "30")
    try:
        retry_after = float(raw_value)
    except (TypeError, ValueError):
        # Covers non-numeric values such as the HTTP-date form of the header.
        retry_after = None
    # float() happily parses "-5", "nan" and "inf", but time.sleep() raises
    # on all of them. The chained comparison is False for nan as well, so
    # every unusable value is routed to the default delay.
    if retry_after is None or not 0 <= retry_after < float("inf"):
        logger.warning(
            "Invalid retry-after header value: %s",
            raw_value,
        )
        retry_after = default_delay
    # Add exponential backoff plus up to one second of jitter so that
    # concurrent clients do not all retry at the same instant.
    retry_after = retry_after * 2 ** (attempt - 1) + random.random()
    time.sleep(retry_after)
    return True

try:
self.request_with_retries(
"post",
Expand All @@ -1100,6 +1127,8 @@ def batch_ingest_runs(
},
},
to_ignore=(ls_utils.LangSmithConflictError,),
stop_after_attempt=3,
handle_response=handle_429,
)
except Exception as e:
logger.warning(f"Failed to batch ingest runs: {repr(e)}")
Expand Down
4 changes: 3 additions & 1 deletion python/tests/integration_tests/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,9 @@ async def my_async_generator(num: int) -> AsyncGenerator[str, None]:
"Async yielded 4",
]

poll_runs_until_count(langchain_client, project_name, 1, max_retries=20)
poll_runs_until_count(
langchain_client, project_name, 1, max_retries=20, sleep_time=5
)
runs = list(langchain_client.list_runs(project_name=project_name))
run = runs[0]
assert run.run_type == "chain"
Expand Down
29 changes: 29 additions & 0 deletions python/tests/unit_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,3 +722,32 @@ def test_http_status_404_handling(mock_raise_for_status):
mock_raise_for_status.side_effect = HTTPError()
with pytest.raises(ls_utils.LangSmithNotFoundError):
client.request_with_retries("GET", "https://test.url", {})


@patch("langsmith.client.ls_utils.raise_for_status_with_text")
def test_batch_ingest_run_retry_on_429(mock_raise_for_status):
    """batch_ingest_runs issues the POST three times when each reply is a 429."""
    session = MagicMock()
    client = Client(api_key="test", session=session)

    # Every request made through the session gets the same throttled reply.
    throttled = MagicMock()
    throttled.status_code = 429
    throttled.headers = {"retry-after": "0.5"}
    session.request.return_value = throttled
    # Force the status check to fail so the retry path is exercised.
    mock_raise_for_status.side_effect = HTTPError()

    run_payload = {
        "name": "test",
        "id": str(uuid.uuid4()),
        "trace_id": str(uuid.uuid4()),
        "dotted_order": str(uuid.uuid4()),
    }
    client.batch_ingest_runs(create=[run_payload])

    recorded_calls = session.request.call_args_list
    # The session may also have served GET requests, so the total is only a
    # lower bound.
    assert session.request.call_count >= 3
    # Exactly three of the recorded requests must be POSTs.
    post_calls = [call for call in recorded_calls if call[0][0] == "post"]
    assert len(post_calls) == 3

0 comments on commit fc65996

Please sign in to comment.