Skip to content

Commit 7fbbe66

Browse files
authored
Lazily instantiate underlying client (#170)
#147 changed some logic in how the underlying networking client is initialized. Some clients that set the `REPLICATE_API_TOKEN` environment variable after importing the `replicate` package relied on that behavior, and are now getting authentication errors (#169) This PR restores some of the original behavior by lazily instantiating the underlying client until the first request is made. This should resolve the regression observed by users who were modifying the environment after import. --------- Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent 30e4be5 commit 7fbbe66

File tree

2 files changed

+50
-24
lines changed

2 files changed

+50
-24
lines changed

replicate/client.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
class Client:
2727
"""A Replicate API client library"""
2828

29+
__client: Optional[httpx.Client] = None
30+
2931
def __init__(
3032
self,
3133
api_token: Optional[str] = None,
@@ -36,37 +38,41 @@ def __init__(
3638
) -> None:
3739
super().__init__()
3840

39-
api_token = api_token or os.environ.get("REPLICATE_API_TOKEN")
40-
41-
base_url = base_url or os.environ.get(
42-
"REPLICATE_API_BASE_URL", "https://api.replicate.com"
41+
self._api_token = api_token
42+
self._base_url = (
43+
base_url
44+
or os.environ.get("REPLICATE_API_BASE_URL")
45+
or "https://api.replicate.com"
4346
)
44-
45-
timeout = timeout or httpx.Timeout(
47+
self._timeout = timeout or httpx.Timeout(
4648
5.0, read=30.0, write=30.0, connect=5.0, pool=10.0
4749
)
50+
self._transport = kwargs.pop("transport", httpx.HTTPTransport())
51+
self._client_kwargs = kwargs
4852

4953
self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))
5054

51-
headers = {
52-
"User-Agent": f"replicate-python/{__version__}",
53-
}
54-
55-
if api_token is not None and api_token != "":
56-
headers["Authorization"] = f"Token {api_token}"
57-
58-
transport = kwargs.pop("transport", httpx.HTTPTransport())
59-
60-
self._client = self._build_client(
61-
**kwargs,
62-
base_url=base_url,
63-
headers=headers,
64-
timeout=timeout,
65-
transport=RetryTransport(wrapped_transport=transport),
66-
)
55+
@property
56+
def _client(self) -> httpx.Client:
57+
if self.__client is None:
58+
headers = {
59+
"User-Agent": f"replicate-python/{__version__}",
60+
}
61+
62+
api_token = self._api_token or os.environ.get("REPLICATE_API_TOKEN")
63+
64+
if api_token is not None and api_token != "":
65+
headers["Authorization"] = f"Token {api_token}"
66+
67+
self.__client = httpx.Client(
68+
**self._client_kwargs,
69+
base_url=self._base_url,
70+
headers=headers,
71+
timeout=self._timeout,
72+
transport=RetryTransport(wrapped_transport=self._transport),
73+
)
6774

68-
def _build_client(self, **kwargs) -> httpx.Client:
69-
return httpx.Client(**kwargs)
75+
return self.__client
7076

7177
def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
7278
resp = self._client.request(method, path, **kwargs)

tests/test_client.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import os
2+
from unittest import mock
3+
4+
import httpx
5+
import pytest
6+
7+
8+
@pytest.mark.asyncio
9+
async def test_authorization_when_setting_environ_after_import():
10+
import replicate
11+
12+
token = "test-set-after-import" # noqa: S105
13+
14+
with mock.patch.dict(
15+
os.environ,
16+
{"REPLICATE_API_TOKEN": token},
17+
):
18+
client: httpx.Client = replicate.default_client._client
19+
assert "Authorization" in client.headers
20+
assert client.headers["Authorization"] == f"Token {token}"

0 commit comments

Comments
 (0)