Skip to content

Commit 29739c5

Browse files
aronmattt
authored and committed
[async] Include prediction id upload request (#1788)
* Cast TraceContext into Mapping[str, str] to fix linter * Include prediction id upload request Based on #1667 This PR introduces two small changes to the file upload interface. 1. We now allow downstream services to include the destination of the asset in a `Location` header, rather than assuming that it's the same as the final upload URL (either the one passed via `--upload-url` or the result of a 307 redirect response). 2. We now include the `X-Prediction-Id` header in the upload request; this allows the downstream client to potentially do configuration/routing based on the prediction ID. This ID should be considered unsafe and needs to be validated by the downstream service. * Extract ChunkFileReader into top-level class --------- Co-authored-by: Mattt Zmuda <mattt@replicate.com>
1 parent d33d106 commit 29739c5

File tree

3 files changed

+158
-28
lines changed

3 files changed

+158
-28
lines changed

python/cog/server/clients.py

Lines changed: 60 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,17 @@
22
import io
33
import mimetypes
44
import os
5-
from typing import Any, AsyncIterator, Awaitable, Callable, Collection, Dict, Optional
5+
from typing import (
6+
Any,
7+
AsyncIterator,
8+
Awaitable,
9+
Callable,
10+
Collection,
11+
Dict,
12+
Mapping,
13+
Optional,
14+
cast,
15+
)
616
from urllib.parse import urlparse
717

818
import httpx
@@ -59,7 +69,7 @@ def webhook_headers() -> "dict[str, str]":
5969

6070
async def on_request_trace_context_hook(request: httpx.Request) -> None:
6171
ctx = current_trace_context() or {}
62-
request.headers.update(ctx)
72+
request.headers.update(cast(Mapping[str, str], ctx))
6373

6474

6575
def httpx_webhook_client() -> httpx.AsyncClient:
@@ -109,6 +119,22 @@ def httpx_file_client() -> httpx.AsyncClient:
109119
)
110120

111121

122+
class ChunkFileReader:
    """Async-iterable view of a file object that yields its content as bytes.

    The handle is rewound to offset 0 each time iteration starts, and the
    content is read in pieces of at most 1024 * 1024 bytes (or characters,
    for handles that return ``str``).
    """

    def __init__(self, fh: io.IOBase) -> None:
        self.fh = fh

    async def __aiter__(self) -> AsyncIterator[bytes]:
        # Start from the beginning regardless of where the cursor was left.
        self.fh.seek(0)
        piece = self.fh.read(1024 * 1024)
        while piece:
            # Handles that return str are encoded so only bytes are yielded.
            yield piece.encode("utf-8") if isinstance(piece, str) else piece
            piece = self.fh.read(1024 * 1024)
        log.info("finished reading file")
136+
137+
112138
# there's a case for splitting this apart or inlining parts of it
113139
# I'm somewhat sympathetic to separating webhooks and files, but they both have
114140
# the same semantics of holding a client for the lifetime of runner
@@ -159,10 +185,11 @@ async def sender(response: Any, event: WebhookEvent) -> None:
159185

160186
# files
161187

162-
async def upload_file(self, fh: io.IOBase, url: Optional[str]) -> str:
188+
async def upload_file(
189+
self, fh: io.IOBase, *, url: Optional[str], prediction_id: Optional[str]
190+
) -> str:
163191
"""put file to signed endpoint"""
164192
log.debug("upload_file")
165-
fh.seek(0)
166193
# try to guess the filename of the given object
167194
name = getattr(fh, "name", "file")
168195
filename = os.path.basename(name) or "file"
@@ -180,17 +207,12 @@ async def upload_file(self, fh: io.IOBase, url: Optional[str]) -> str:
180207
# ensure trailing slash
181208
url_with_trailing_slash = url if url.endswith("/") else url + "/"
182209

183-
async def chunk_file_reader() -> AsyncIterator[bytes]:
184-
while 1:
185-
chunk = fh.read(1024 * 1024)
186-
if isinstance(chunk, str):
187-
chunk = chunk.encode("utf-8")
188-
if not chunk:
189-
log.info("finished reading file")
190-
break
191-
yield chunk
192-
193210
url = url_with_trailing_slash + filename
211+
212+
headers = {"Content-Type": content_type}
213+
if prediction_id is not None:
214+
headers["X-Prediction-ID"] = prediction_id
215+
194216
# this is a somewhat unfortunate hack, but it works
195217
# and is critical for upload training/quantization outputs
196218
# if we get multipart uploads working or a separate API route
@@ -200,29 +222,36 @@ async def chunk_file_reader() -> AsyncIterator[bytes]:
200222
resp1 = await self.file_client.put(
201223
url,
202224
content=b"",
203-
headers={"Content-Type": content_type},
225+
headers=headers,
204226
follow_redirects=False,
205227
)
206228
if resp1.status_code == 307 and resp1.headers["Location"]:
207229
log.info("got file upload redirect from api")
208230
url = resp1.headers["Location"]
231+
209232
log.info("doing real upload to %s", url)
210233
resp = await self.file_client.put(
211234
url,
212-
content=chunk_file_reader(),
213-
headers={"Content-Type": content_type},
235+
content=ChunkFileReader(fh),
236+
headers=headers,
214237
)
215238
# TODO: if file size is >1MB, show upload throughput
216239
resp.raise_for_status()
217240

218-
# strip any signing gubbins from the URL
219-
final_url = urlparse(str(resp.url))._replace(query="").geturl()
241+
# Try to extract the final asset URL from the `Location` header
242+
# otherwise fallback to the URL of the final request.
243+
final_url = str(resp.url)
244+
if "location" in resp.headers:
245+
final_url = resp.headers.get("location")
220246

221-
return final_url
247+
# strip any signing gubbins from the URL
248+
return urlparse(final_url)._replace(query="").geturl()
222249

223250
# this previously lived in json.upload_files, but it's clearer here
224251
# this is a great pattern that should be adopted for input files
225-
async def upload_files(self, obj: Any, url: Optional[str]) -> Any:
252+
async def upload_files(
253+
self, obj: Any, *, url: Optional[str], prediction_id: Optional[str]
254+
) -> Any:
226255
"""
227256
Iterates through an object from make_encodeable and uploads any files.
228257
When a file is encountered, it will be passed to upload_file. Any paths will be opened and converted to files.
@@ -234,15 +263,21 @@ async def upload_files(self, obj: Any, url: Optional[str]) -> Any:
234263
# TODO: upload concurrently
235264
if isinstance(obj, dict):
236265
return {
237-
key: await self.upload_files(value, url) for key, value in obj.items()
266+
key: await self.upload_files(
267+
value, url=url, prediction_id=prediction_id
268+
)
269+
for key, value in obj.items()
238270
}
239271
if isinstance(obj, list):
240-
return [await self.upload_files(value, url) for value in obj]
272+
return [
273+
await self.upload_files(value, url=url, prediction_id=prediction_id)
274+
for value in obj
275+
]
241276
if isinstance(obj, Path):
242277
with obj.open("rb") as f:
243-
return await self.upload_file(f, url)
278+
return await self.upload_file(f, url=url, prediction_id=prediction_id)
244279
if isinstance(obj, io.IOBase):
245-
return await self.upload_file(obj, url)
280+
return await self.upload_file(obj, url=url, prediction_id=prediction_id)
246281
return obj
247282

248283
# we could also handle inputs here, with a convert_prediction_input function

python/cog/server/runner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,9 @@ async def _send_webhook(self, event: schema.WebhookEvent) -> None:
291291
async def _upload_files(self, output: Any) -> Any:
292292
try:
293293
# TODO: clean up output files
294-
return await self._client_manager.upload_files(output, self._upload_url)
294+
return await self._client_manager.upload_files(
295+
output, url=self._upload_url, prediction_id=self.p.id
296+
)
295297
except Exception as error:
296298
# If something goes wrong uploading a file, it's irrecoverable.
297299
# The re-raised exception will be caught and cause the prediction

python/tests/server/test_clients.py

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import httpx
12
import os
3+
import responses
24
import tempfile
35

46
import cog
@@ -7,12 +9,103 @@
79

810

911
@pytest.mark.asyncio
10-
async def test_upload_files():
12+
async def test_upload_files_without_url():
1113
client_manager = ClientManager()
1214
temp_dir = tempfile.mkdtemp()
1315
temp_path = os.path.join(temp_dir, "my_file.txt")
1416
with open(temp_path, "w") as fh:
1517
fh.write("file content")
1618
obj = {"path": cog.Path(temp_path)}
17-
result = await client_manager.upload_files(obj, None)
19+
result = await client_manager.upload_files(obj, url=None, prediction_id=None)
1820
assert result == {"path": "data:text/plain;base64,ZmlsZSBjb250ZW50"}
21+
22+
23+
@pytest.mark.asyncio
@pytest.mark.respx(base_url="https://example.com")
async def test_upload_files_with_url(respx_mock):
    """A path in the output is replaced by the URL the file was PUT to."""
    uploader = respx_mock.put("/bucket/my_file.txt").mock(
        return_value=httpx.Response(201)
    )

    client_manager = ClientManager()
    # TemporaryDirectory removes the fixture afterwards; a bare mkdtemp()
    # leaves a stray directory behind on every test run.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = os.path.join(temp_dir, "my_file.txt")
        with open(temp_path, "w") as fh:
            fh.write("file content")

        obj = {"path": cog.Path(temp_path)}
        result = await client_manager.upload_files(
            obj, url="https://example.com/bucket", prediction_id=None
        )

    assert result == {"path": "https://example.com/bucket/my_file.txt"}
    assert uploader.call_count == 1
43+
44+
45+
@pytest.mark.asyncio
@pytest.mark.respx(base_url="https://example.com")
async def test_upload_files_with_prediction_id(respx_mock):
    """The prediction id is sent as the ``x-prediction-id`` request header."""
    # The route only matches requests that carry the expected header, so a
    # missing or wrong header shows up as an unmatched request.
    uploader = respx_mock.put(
        "/bucket/my_file.txt", headers={"x-prediction-id": "p123"}
    ).mock(return_value=httpx.Response(201))

    client_manager = ClientManager()
    # TemporaryDirectory removes the fixture afterwards; a bare mkdtemp()
    # leaves a stray directory behind on every test run.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = os.path.join(temp_dir, "my_file.txt")
        with open(temp_path, "w") as fh:
            fh.write("file content")

        obj = {"path": cog.Path(temp_path)}
        result = await client_manager.upload_files(
            obj, url="https://example.com/bucket", prediction_id="p123"
        )

    assert result == {"path": "https://example.com/bucket/my_file.txt"}
    assert uploader.call_count == 1
65+
66+
67+
@pytest.mark.asyncio
@pytest.mark.respx(base_url="https://example.com")
async def test_upload_files_with_location_header(respx_mock):
    """A ``Location`` response header overrides the request URL in the result."""
    uploader = respx_mock.put("/bucket/my_file.txt").mock(
        return_value=httpx.Response(
            201, headers={"Location": "https://cdn.example.com/bucket/my_file.txt"}
        )
    )

    client_manager = ClientManager()
    # TemporaryDirectory removes the fixture afterwards; a bare mkdtemp()
    # leaves a stray directory behind on every test run.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = os.path.join(temp_dir, "my_file.txt")
        with open(temp_path, "w") as fh:
            fh.write("file content")

        obj = {"path": cog.Path(temp_path)}
        result = await client_manager.upload_files(
            obj, url="https://example.com/bucket", prediction_id=None
        )

    assert result == {"path": "https://cdn.example.com/bucket/my_file.txt"}
    assert uploader.call_count == 1
89+
90+
91+
@pytest.mark.asyncio
@pytest.mark.respx(base_url="https://example.com")
async def test_upload_files_with_retry(respx_mock):
    """An upload that always receives 502 raises ``httpx.HTTPStatusError``."""
    uploader = respx_mock.put("/bucket/my_file.txt").mock(
        return_value=httpx.Response(502)
    )

    client_manager = ClientManager()
    # TemporaryDirectory removes the fixture afterwards; a bare mkdtemp()
    # leaves a stray directory behind on every test run.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = os.path.join(temp_dir, "my_file.txt")
        with open(temp_path, "w") as fh:
            fh.write("file content")

        obj = {"path": cog.Path(temp_path)}
        # The call itself must raise, so no return value is captured: any
        # statement after it inside this block would never run, and a name
        # bound by it would be undefined afterwards.
        with pytest.raises(httpx.HTTPStatusError):
            await client_manager.upload_files(
                obj, url="https://example.com/bucket", prediction_id=None
            )

    # NOTE(review): the expected count of three requests is kept from the
    # original assertion; presumably it reflects the file client's retry
    # policy - confirm against the client configuration.
    assert uploader.call_count == 3

0 commit comments

Comments
 (0)