Skip to content

Commit df0637a

Browse files
nickstenningmattt
andcommitted
Don't persist or send cookies to Replicate API
Replicate's API does not use cookies and even if we return cookies the client should not save and replay them. Co-authored-by: Mattt Zmuda <mattt@replicate.com> Signed-off-by: Nick Stenning <nick@whiteink.com> Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent 44e10a4 commit df0637a

File tree

1 file changed

+21
-2
lines changed

1 file changed

+21
-2
lines changed

replicate/client.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import requests
77
from requests.adapters import HTTPAdapter, Retry
8+
from requests.cookies import RequestsCookieJar
89

910
from replicate.__about__ import __version__
1011
from replicate.exceptions import ModelError, ReplicateError
@@ -25,7 +26,7 @@ def __init__(self, api_token: Optional[str] = None) -> None:
2526
self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))
2627

2728
# TODO: make thread safe
28-
self.read_session = requests.Session()
29+
self.read_session = _create_session()
2930
read_retries = Retry(
3031
total=5,
3132
backoff_factor=2,
@@ -50,7 +51,7 @@ def __init__(self, api_token: Optional[str] = None) -> None:
5051
self.read_session.mount("http://", HTTPAdapter(max_retries=read_retries))
5152
self.read_session.mount("https://", HTTPAdapter(max_retries=read_retries))
5253

53-
self.write_session = requests.Session()
54+
self.write_session = _create_session()
5455
write_retries = Retry(
5556
total=5,
5657
backoff_factor=2,
@@ -138,3 +139,21 @@ def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]:
138139
if prediction.status == "failed":
139140
raise ModelError(prediction.error)
140141
return prediction.output
142+
143+
144+
class _NonpersistentCookieJar(RequestsCookieJar):
145+
"""
146+
A cookie jar that doesn't persist cookies between requests.
147+
"""
148+
149+
def set(self, name, value, **kwargs) -> None:
150+
return
151+
152+
def set_cookie(self, cookie, *args, **kwargs) -> None:
153+
return
154+
155+
156+
def _create_session() -> requests.Session:
157+
s = requests.Session()
158+
s.cookies = _NonpersistentCookieJar()
159+
return s

0 commit comments

Comments
 (0)