[async] Include prediction id in upload request #1788

Merged: 3 commits, Jul 17, 2024

85 changes: 60 additions & 25 deletions python/cog/server/clients.py
@@ -2,7 +2,17 @@
import io
import mimetypes
import os
from typing import Any, AsyncIterator, Awaitable, Callable, Collection, Dict, Optional
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Collection,
Dict,
Mapping,
Optional,
cast,
)
from urllib.parse import urlparse

import httpx
@@ -62,7 +72,7 @@ def webhook_headers() -> "dict[str, str]":

async def on_request_trace_context_hook(request: httpx.Request) -> None:
ctx = current_trace_context() or {}
request.headers.update(ctx)
request.headers.update(cast(Mapping[str, str], ctx))
Contributor Author commented: Fixes the linting issue on async.
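For context, a minimal sketch of the typing problem the `cast` presumably works around. It assumes the trace context is a TypedDict-style mapping (the real `TraceContext` type lives elsewhere in cog and may differ): the type checker will not accept it as the `Mapping[str, str]` that `httpx.Headers.update` expects, and `cast()` narrows the static type without any runtime effect.

```python
# Hedged sketch only: TraceContext and current_trace_context are stand-ins
# for cog's real helpers, not their actual definitions.
from typing import Mapping, Optional, TypedDict, cast

import httpx


class TraceContext(TypedDict, total=False):
    traceparent: str
    tracestate: str


def current_trace_context() -> Optional[TraceContext]:
    return {"traceparent": "00-0123456789abcdef-0123456789ab-01"}


async def on_request_trace_context_hook(request: httpx.Request) -> None:
    ctx = current_trace_context() or {}
    # Without the cast, the type checker flags update() because a TypedDict is
    # not known to be a Mapping[str, str]; cast() only changes the static type.
    request.headers.update(cast(Mapping[str, str], ctx))
```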



def httpx_webhook_client() -> httpx.AsyncClient:
@@ -111,6 +121,22 @@ def httpx_file_client() -> httpx.AsyncClient:
)


class ChunkFileReader:
def __init__(self, fh: io.IOBase) -> None:
self.fh = fh

async def __aiter__(self) -> AsyncIterator[bytes]:
self.fh.seek(0)
while True:
chunk = self.fh.read(1024 * 1024)
if isinstance(chunk, str):
chunk = chunk.encode("utf-8")
if not chunk:
log.info("finished reading file")
break
yield chunk


# there's a case for splitting this apart or inlining parts of it
# I'm somewhat sympathetic to separating webhooks and files, but they both have
# the same semantics of holding a client for the lifetime of runner
@@ -167,10 +193,11 @@ async def sender(response: PredictionResponse, event: WebhookEvent) -> None:

# files

async def upload_file(self, fh: io.IOBase, url: Optional[str]) -> str:
async def upload_file(
self, fh: io.IOBase, *, url: Optional[str], prediction_id: Optional[str]
) -> str:
"""put file to signed endpoint"""
log.debug("upload_file")
fh.seek(0)
# try to guess the filename of the given object
name = getattr(fh, "name", "file")
filename = os.path.basename(name) or "file"
@@ -188,17 +215,12 @@ async def upload_file(self, fh: io.IOBase, url: Optional[str]) -> str:
# ensure trailing slash
url_with_trailing_slash = url if url.endswith("/") else url + "/"

async def chunk_file_reader() -> AsyncIterator[bytes]:
while 1:
chunk = fh.read(1024 * 1024)
if isinstance(chunk, str):
chunk = chunk.encode("utf-8")
if not chunk:
log.info("finished reading file")
break
yield chunk

url = url_with_trailing_slash + filename

headers = {"Content-Type": content_type}
if prediction_id is not None:
headers["X-Prediction-ID"] = prediction_id

# this is a somewhat unfortunate hack, but it works
# and is critical for upload training/quantization outputs
# if we get multipart uploads working or a separate API route
@@ -208,29 +230,36 @@ async def chunk_file_reader() -> AsyncIterator[bytes]:
resp1 = await self.file_client.put(
url,
content=b"",
headers={"Content-Type": content_type},
headers=headers,
follow_redirects=False,
)
if resp1.status_code == 307 and resp1.headers["Location"]:
log.info("got file upload redirect from api")
url = resp1.headers["Location"]

log.info("doing real upload to %s", url)
resp = await self.file_client.put(
url,
content=chunk_file_reader(),
headers={"Content-Type": content_type},
content=ChunkFileReader(fh),
headers=headers,
)
# TODO: if file size is >1MB, show upload throughput
resp.raise_for_status()

# strip any signing gubbins from the URL
final_url = urlparse(str(resp.url))._replace(query="").geturl()
# Try to extract the final asset URL from the `Location` header
# otherwise fallback to the URL of the final request.
final_url = str(resp.url)
if "location" in resp.headers:
final_url = resp.headers.get("location")

return final_url
# strip any signing gubbins from the URL
return urlparse(final_url)._replace(query="").geturl()

# this previously lived in json.upload_files, but it's clearer here
# this is a great pattern that should be adopted for input files
async def upload_files(self, obj: Any, url: Optional[str]) -> Any:
async def upload_files(
self, obj: Any, *, url: Optional[str], prediction_id: Optional[str]
) -> Any:
"""
Iterates through an object from make_encodeable and uploads any files.
When a file is encountered, it will be passed to upload_file. Any paths will be opened and converted to files.
@@ -245,15 +274,21 @@ async def upload_files(self, obj: Any, url: Optional[str]) -> Any:
# TODO: upload concurrently
if isinstance(obj, dict):
return {
key: await self.upload_files(value, url) for key, value in obj.items()
key: await self.upload_files(
value, url=url, prediction_id=prediction_id
)
for key, value in obj.items()
}
if isinstance(obj, list):
return [await self.upload_files(value, url) for value in obj]
return [
await self.upload_files(value, url=url, prediction_id=prediction_id)
for value in obj
]
if isinstance(obj, Path):
with obj.open("rb") as f:
return await self.upload_file(f, url)
return await self.upload_file(f, url=url, prediction_id=prediction_id)
if isinstance(obj, io.IOBase):
return await self.upload_file(obj, url)
return await self.upload_file(obj, url=url, prediction_id=prediction_id)
return obj

# inputs
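To summarize the flow the clients.py changes above implement, here is a rough standalone sketch: an empty PUT with redirects disabled so the API can hand back a signed destination, then the real streaming PUT carrying the new `X-Prediction-ID` header, and finally the asset URL taken from `Location` (or the response URL) with the signing query string stripped. The endpoint, content type, and prediction id below are illustrative, not part of the PR.

```python
# Condensed, illustrative version of ClientManager.upload_file above;
# error handling, retries, and content-type guessing are omitted.
import io
from urllib.parse import urlparse

import httpx


async def upload(fh: io.IOBase, url: str, prediction_id: str) -> str:
    headers = {
        "Content-Type": "application/octet-stream",
        "X-Prediction-ID": prediction_id,
    }

    async def chunks():
        # Stream the file in 1 MiB chunks, like ChunkFileReader.
        fh.seek(0)
        while True:
            chunk = fh.read(1024 * 1024)
            if not chunk:
                break
            yield chunk

    async with httpx.AsyncClient() as client:
        # Step 1: empty PUT with redirects disabled, so a 307 can point at the
        # real (signed) upload destination.
        resp1 = await client.put(
            url, content=b"", headers=headers, follow_redirects=False
        )
        if resp1.status_code == 307 and resp1.headers.get("Location"):
            url = resp1.headers["Location"]

        # Step 2: the actual streaming upload.
        resp = await client.put(url, content=chunks(), headers=headers)
        resp.raise_for_status()

    # Prefer the Location header for the final asset URL, then drop any
    # signing parameters from the query string.
    final_url = resp.headers.get("location", str(resp.url))
    return urlparse(final_url)._replace(query="").geturl()
```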
8 changes: 6 additions & 2 deletions python/cog/server/runner.py
@@ -122,7 +122,9 @@ def __init__(
self._shutdown_event = shutdown_event # __main__ waits for this event

self._upload_url = upload_url
self._predictions: dict[str, tuple[schema.PredictionResponse, PredictionTask]] = {}
self._predictions: dict[
str, tuple[schema.PredictionResponse, PredictionTask]
] = {}
self._predictions_in_flight: set[str] = set()
# it would be lovely to merge these but it's not fully clear how best to handle it
# since idempotent requests can kinda come whenever?
@@ -536,7 +538,9 @@ async def _send_webhook(self, event: schema.WebhookEvent) -> None:
async def _upload_files(self, output: Any) -> Any:
try:
# TODO: clean up output files
return await self._client_manager.upload_files(output, self._upload_url)
return await self._client_manager.upload_files(
output, url=self._upload_url, prediction_id=self.p.id
)
except Exception as error:
# If something goes wrong uploading a file, it's irrecoverable.
# The re-raised exception will be caught and cause the prediction
97 changes: 95 additions & 2 deletions python/tests/server/test_clients.py
@@ -1,4 +1,6 @@
import httpx
import os
import responses
import tempfile

import cog
@@ -7,12 +9,103 @@


@pytest.mark.asyncio
async def test_upload_files():
async def test_upload_files_without_url():
client_manager = ClientManager()
temp_dir = tempfile.mkdtemp()
temp_path = os.path.join(temp_dir, "my_file.txt")
with open(temp_path, "w") as fh:
fh.write("file content")
obj = {"path": cog.Path(temp_path)}
result = await client_manager.upload_files(obj, None)
result = await client_manager.upload_files(obj, url=None, prediction_id=None)
assert result == {"path": "data:text/plain;base64,ZmlsZSBjb250ZW50"}


@pytest.mark.asyncio
@pytest.mark.respx(base_url="https://example.com")
async def test_upload_files_with_url(respx_mock):
uploader = respx_mock.put("/bucket/my_file.txt").mock(
return_value=httpx.Response(201)
)

client_manager = ClientManager()
temp_dir = tempfile.mkdtemp()
temp_path = os.path.join(temp_dir, "my_file.txt")
with open(temp_path, "w") as fh:
fh.write("file content")

obj = {"path": cog.Path(temp_path)}
result = await client_manager.upload_files(
obj, url="https://example.com/bucket", prediction_id=None
)
assert result == {"path": "https://example.com/bucket/my_file.txt"}

assert uploader.call_count == 1


@pytest.mark.asyncio
@pytest.mark.respx(base_url="https://example.com")
async def test_upload_files_with_prediction_id(respx_mock):
uploader = respx_mock.put(
"/bucket/my_file.txt", headers={"x-prediction-id": "p123"}
).mock(return_value=httpx.Response(201))

client_manager = ClientManager()
temp_dir = tempfile.mkdtemp()
temp_path = os.path.join(temp_dir, "my_file.txt")
with open(temp_path, "w") as fh:
fh.write("file content")

obj = {"path": cog.Path(temp_path)}
result = await client_manager.upload_files(
obj, url="https://example.com/bucket", prediction_id="p123"
)
assert result == {"path": "https://example.com/bucket/my_file.txt"}

assert uploader.call_count == 1


@pytest.mark.asyncio
@pytest.mark.respx(base_url="https://example.com")
async def test_upload_files_with_location_header(respx_mock):
uploader = respx_mock.put("/bucket/my_file.txt").mock(
return_value=httpx.Response(
201, headers={"Location": "https://cdn.example.com/bucket/my_file.txt"}
)
)

client_manager = ClientManager()
temp_dir = tempfile.mkdtemp()
temp_path = os.path.join(temp_dir, "my_file.txt")
with open(temp_path, "w") as fh:
fh.write("file content")

obj = {"path": cog.Path(temp_path)}
result = await client_manager.upload_files(
obj, url="https://example.com/bucket", prediction_id=None
)
assert result == {"path": "https://cdn.example.com/bucket/my_file.txt"}

assert uploader.call_count == 1


@pytest.mark.asyncio
@pytest.mark.respx(base_url="https://example.com")
async def test_upload_files_with_retry(respx_mock):
uploader = respx_mock.put("/bucket/my_file.txt").mock(
return_value=httpx.Response(502)
)

client_manager = ClientManager()
temp_dir = tempfile.mkdtemp()
temp_path = os.path.join(temp_dir, "my_file.txt")
with open(temp_path, "w") as fh:
fh.write("file content")

obj = {"path": cog.Path(temp_path)}
with pytest.raises(httpx.HTTPStatusError):
result = await client_manager.upload_files(
obj, url="https://example.com/bucket", prediction_id=None
)

assert result == {"path": "https://cdn.example.com/bucket/my_file.txt"}
assert uploader.call_count == 3