Skip to content

Commit 1e3b2da

Browse files
committed
Update client to use httpx
Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent bbafa98 commit 1e3b2da

File tree

3 files changed

+49
-106
lines changed

3 files changed

+49
-106
lines changed

replicate/client.py

Lines changed: 43 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,105 +1,63 @@
11
import os
22
import re
3-
from json import JSONDecodeError
4-
from typing import Any, Dict, Iterator, Optional, Union
3+
from typing import Any, Iterator, Optional, Union
54

6-
import requests
7-
from requests.adapters import HTTPAdapter, Retry
8-
from requests.cookies import RequestsCookieJar
5+
import httpx
96

10-
from replicate.__about__ import __version__
11-
from replicate.exceptions import ModelError, ReplicateError
12-
from replicate.model import ModelCollection
13-
from replicate.prediction import PredictionCollection
14-
from replicate.training import TrainingCollection
7+
from .__about__ import __version__
8+
from .exceptions import ModelError, ReplicateError
9+
from .model import ModelCollection
10+
from .prediction import PredictionCollection
11+
from .training import TrainingCollection
1512

1613

1714
class Client:
18-
def __init__(self, api_token: Optional[str] = None) -> None:
15+
"""A Replicate API client library"""
16+
17+
def __init__(
18+
self,
19+
api_token: Optional[str] = None,
20+
*,
21+
base_url: Optional[str] = None,
22+
timeout: Optional[httpx.Timeout] = None,
23+
**kwargs,
24+
) -> None:
1925
super().__init__()
20-
# Client is instantiated at import time, so do as little as possible.
21-
# This includes resolving environment variables -- they might be set programmatically.
22-
self.api_token = api_token
23-
self.base_url = os.environ.get(
26+
27+
api_token = api_token or os.environ.get("REPLICATE_API_TOKEN")
28+
29+
base_url = base_url or os.environ.get(
2430
"REPLICATE_API_BASE_URL", "https://api.replicate.com"
2531
)
26-
self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))
2732

28-
# TODO: make thread safe
29-
self.read_session = _create_session()
30-
read_retries = Retry(
31-
total=5,
32-
backoff_factor=2,
33-
# Only retry 500s on GET so we don't unintionally mutute data
34-
allowed_methods=["GET"],
35-
# https://support.cloudflare.com/hc/en-us/articles/115003011431-Troubleshooting-Cloudflare-5XX-errors
36-
status_forcelist=[
37-
429,
38-
500,
39-
502,
40-
503,
41-
504,
42-
520,
43-
521,
44-
522,
45-
523,
46-
524,
47-
526,
48-
527,
49-
],
33+
timeout = timeout or httpx.Timeout(
34+
5.0, read=30.0, write=30.0, connect=5.0, pool=10.0
5035
)
51-
self.read_session.mount("http://", HTTPAdapter(max_retries=read_retries))
52-
self.read_session.mount("https://", HTTPAdapter(max_retries=read_retries))
53-
54-
self.write_session = _create_session()
55-
write_retries = Retry(
56-
total=5,
57-
backoff_factor=2,
58-
allowed_methods=["POST", "PUT"],
59-
# Only retry POST/PUT requests on rate limits, so we don't unintionally mutute data
60-
status_forcelist=[429],
61-
)
62-
self.write_session.mount("http://", HTTPAdapter(max_retries=write_retries))
63-
self.write_session.mount("https://", HTTPAdapter(max_retries=write_retries))
64-
65-
def _request(self, method: str, path: str, **kwargs) -> requests.Response:
66-
# from requests.Session
67-
if method in ["GET", "OPTIONS"]:
68-
kwargs.setdefault("allow_redirects", True)
69-
if method in ["HEAD"]:
70-
kwargs.setdefault("allow_redirects", False)
71-
kwargs.setdefault("headers", {})
72-
kwargs["headers"].update(self._headers())
73-
session = self.read_session
74-
if method in ["POST", "PUT", "DELETE", "PATCH"]:
75-
session = self.write_session
76-
resp = session.request(method, self.base_url + path, **kwargs)
77-
if 400 <= resp.status_code < 600:
78-
try:
79-
raise ReplicateError(resp.json()["detail"])
80-
except (JSONDecodeError, KeyError):
81-
pass
82-
raise ReplicateError(f"HTTP error: {resp.status_code, resp.reason}")
83-
return resp
8436

85-
def _headers(self) -> Dict[str, str]:
86-
return {
87-
"Authorization": f"Token {self._api_token()}",
37+
self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))
38+
39+
headers = {
40+
"Authorization": f"Token {api_token}",
8841
"User-Agent": f"replicate-python/{__version__}",
8942
}
9043

91-
def _api_token(self) -> str:
92-
token = self.api_token
93-
# Evaluate lazily in case environment variable is set with dotenv, or something
94-
if token is None:
95-
token = os.environ.get("REPLICATE_API_TOKEN")
96-
if not token:
97-
raise ReplicateError(
98-
"""No API token provided. You need to set the REPLICATE_API_TOKEN environment variable or create a client with `replicate.Client(api_token=...)`.
44+
self._client = self._build_client(
45+
**kwargs,
46+
base_url=base_url,
47+
headers=headers,
48+
timeout=timeout,
49+
)
9950

100-
You can find your API key on https://replicate.com"""
101-
)
102-
return token
51+
def _build_client(self, **kwargs) -> httpx.Client:
52+
return httpx.Client(**kwargs)
53+
54+
def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
55+
resp = self._client.request(method, path, **kwargs)
56+
57+
if 400 <= resp.status_code < 600:
58+
raise ReplicateError(resp.json()["detail"])
59+
60+
return resp
10361

10462
@property
10563
def models(self) -> ModelCollection:
@@ -145,21 +103,3 @@ def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]:
145103
if prediction.status == "failed":
146104
raise ModelError(prediction.error)
147105
return prediction.output
148-
149-
150-
class _NonpersistentCookieJar(RequestsCookieJar):
151-
"""
152-
A cookie jar that doesn't persist cookies between requests.
153-
"""
154-
155-
def set(self, name, value, **kwargs) -> None:
156-
return
157-
158-
def set_cookie(self, cookie, *args, **kwargs) -> None:
159-
return
160-
161-
162-
def _create_session() -> requests.Session:
163-
s = requests.Session()
164-
s.cookies = _NonpersistentCookieJar()
165-
return s

replicate/files.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
from typing import Optional
66

7-
import requests
7+
import httpx
88

99

1010
def upload_file(fh: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
@@ -24,7 +24,7 @@ def upload_file(fh: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
2424
if output_file_prefix is not None:
2525
name = getattr(fh, "name", "output")
2626
url = output_file_prefix + os.path.basename(name)
27-
resp = requests.put(url, files={"file": fh}, timeout=None)
27+
resp = httpx.put(url, files={"file": fh}, timeout=None) # type: ignore
2828
resp.raise_for_status()
2929
return url
3030

tests/conftest.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ def mock_replicate_api_token(scope="class"):
99
if os.environ.get("REPLICATE_API_TOKEN", "") != "":
1010
yield
1111
else:
12-
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "test-token"}):
12+
with mock.patch.dict(
13+
os.environ,
14+
{"REPLICATE_API_TOKEN": "test-token", "REPLICATE_POLL_INTERVAL": "0.0"},
15+
):
1316
yield
1417

1518

0 commit comments

Comments
 (0)