Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 29 additions & 23 deletions core/bus_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@

import os
import asyncio
import inspect
import logging
from uuid import uuid4
from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar, cast
from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar

import httpx

Expand Down Expand Up @@ -54,20 +53,23 @@ def __init__(

async def _with_retry(
self,
func: Callable[[], Awaitable[T] | T],
func: Callable[[], Awaitable[T]],
*,
label: str,
retries: int,
backoff: float,
exc_type: type[Exception],
) -> T | None:
) -> T:
"""Execute ``func`` with retries and return its result.

Raises the last encountered exception after exhausting retries.
"""
last_exc: Exception | None = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve robustness, consider handling negative values for retries. A negative number of retries is not meaningful and currently leads to an AssertionError because the retry loop is skipped. It would be safer to treat negative values as 0, ensuring the function behaves predictably even with invalid input.

        if retries < 0:
            retries = 0
        last_exc: Exception | None = None

for attempt in range(retries + 1):
try:
result = func()
if inspect.isawaitable(result):
return await result
return result
return await func()
except exc_type as exc: # pragma: no cover - logging path
last_exc = exc
logger.exception("%s failed", label)
add_entry(kind="bus_client_error", data=f"{label} failed: {exc}")
if self.circuit_breaker and self.circuit_breaker():
Expand All @@ -78,7 +80,8 @@ async def _with_retry(
delay = self.jitter(delay)
if delay:
await asyncio.sleep(delay)
return None
assert last_exc is not None
raise last_exc

async def _request(
self,
Expand All @@ -88,7 +91,7 @@ async def _request(
retries: Optional[int] = None,
backoff: Optional[float] = None,
**kwargs: Any,
) -> httpx.Response | None:
) -> httpx.Response:
return await self._arequest(
method,
endpoint,
Expand All @@ -99,16 +102,20 @@ async def _request(

async def run(self) -> None:
"""Process a single message from the bus."""
r = await self._request(
"get",
"get",
params={
"topic": self.topic,
"group": self.client_id,
"consumer": self.client_id,
},
)
if r and r.status_code == 200:
try:
r = await self._request(
"get",
"get",
params={
"topic": self.topic,
"group": self.client_id,
"consumer": self.client_id,
},
)
except httpx.RequestError:
await asyncio.sleep(1)
return
if r.status_code == 200:
msg = r.json()
data = msg.get("data", {})
text = data.get("text")
Expand Down Expand Up @@ -156,7 +163,7 @@ async def _arequest(
retries: Optional[int] = None,
backoff: Optional[float] = None,
**kwargs: Any,
) -> httpx.Response | None:
) -> httpx.Response:
retries = self.retries if retries is None else retries
backoff = self.backoff if backoff is None else backoff
url = f"{self.base_url}/{endpoint}"
Expand All @@ -167,14 +174,13 @@ async def _arequest(
async def call() -> httpx.Response:
return await self._client.request(method, url, headers=headers, **kwargs)

result = await self._with_retry(
return await self._with_retry(
call,
label=f"{method.upper()} {url}",
retries=retries,
backoff=backoff,
exc_type=httpx.RequestError,
)
return cast(httpx.Response | None, result)

async def publish(
self,
Expand Down
Loading