Skip to content

[async] Support URLFile in the upload_file function #1987

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions python/cog/server/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,19 @@ def __init__(self, fh: io.IOBase) -> None:
self.fh = fh

async def __aiter__(self) -> AsyncIterator[bytes]:
self.fh.seek(0)
if self.fh.seekable():
self.fh.seek(0)

while True:
chunk = self.fh.read(1024 * 1024)

if isinstance(chunk, str):
chunk = chunk.encode("utf-8")

if not chunk:
log.info("finished reading file")
break

yield chunk


Expand Down Expand Up @@ -288,7 +293,8 @@ async def upload_files(
with obj.open("rb") as f:
return await self.upload_file(f, url=url, prediction_id=prediction_id)
if isinstance(obj, io.IOBase):
return await self.upload_file(obj, url=url, prediction_id=prediction_id)
with obj:
return await self.upload_file(obj, url=url, prediction_id=prediction_id)
return obj

# inputs
Expand Down
27 changes: 14 additions & 13 deletions python/cog/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, TypeVar, Union

import httpx
import requests
from pydantic import Field, SecretStr

FILENAME_ILLEGAL_CHARS = set("\u0000/")
Expand Down Expand Up @@ -195,22 +194,18 @@ def unlink(self, missing_ok: bool = False) -> None:
raise


# we would prefer URLFile to stay lazy
# except... that doesn't really work with httpx?


class URLFile(io.IOBase):
"""
URLFile is a proxy object for a :class:`urllib3.response.HTTPResponse`
object that is created lazily. It's a file-like object constructed from a
URL that can survive pickling/unpickling.

This is the only place Cog uses requests
"""

__slots__ = ("__target__", "__url__")

def __init__(self, url: str) -> None:
parsed = urllib.parse.urlparse(url)
object.__setattr__(self, "name", os.path.basename(parsed.path))
object.__setattr__(self, "__url__", url)

# We provide __getstate__ and __setstate__ explicitly to ensure that the
Expand Down Expand Up @@ -242,19 +237,25 @@ def __delattr__(self, name: str) -> None:

# Luckily the only dunder method on HTTPResponse is __iter__
def __iter__(self) -> Iterator[bytes]:
return iter(self.__wrapped__)
response = self.__wrapped__
return iter(response)

@property
def __wrapped__(self) -> Any:
try:
return object.__getattribute__(self, "__target__")
except AttributeError:
url = object.__getattribute__(self, "__url__")
resp = requests.get(url, stream=True)
resp.raise_for_status()
resp.raw.decode_content = True
object.__setattr__(self, "__target__", resp.raw)
return resp.raw

# We create a streaming response here, much like the `requests`
# version in the main 0.9.x branch. The only concerning bit here
# is that the book keeping for closing the response needs to be
# handled elsewhere. There's probably a better design for this
# in the long term.
res = urllib.request.urlopen(url) # noqa: S310
object.__setattr__(self, "__target__", res)

return res

def __repr__(self) -> str:
try:
Expand Down
68 changes: 63 additions & 5 deletions python/tests/server/test_clients.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import httpx
from email.message import Message
import io
import os
import responses
import tempfile
from urllib.response import addinfourl
from unittest import mock

import cog
import httpx
import pytest
from cog.server.clients import ClientManager

pytest.mark.asyncio


@pytest.mark.asyncio
async def test_upload_files_without_url():
Expand Down Expand Up @@ -103,9 +108,62 @@ async def test_upload_files_with_retry(respx_mock):

obj = {"path": cog.Path(temp_path)}
with pytest.raises(httpx.HTTPStatusError):
result = await client_manager.upload_files(
await client_manager.upload_files(
obj, url="https://example.com/bucket", prediction_id=None
)

assert uploader.call_count == 3


@pytest.mark.asyncio
@pytest.mark.respx(base_url="https://example.com")
@mock.patch("urllib.request.urlopen")
async def test_upload_files_with_url_file(urlopen_mock, respx_mock):
fp = io.BytesIO(b"hello world")
urlopen_mock.return_value = addinfourl(
fp=fp, headers=Message(), url="https://example.com/cdn/my_file.txt"
)

uploader = respx_mock.put("/bucket/my_file.txt").mock(
return_value=httpx.Response(
201, headers={"Location": "https://cdn.example.com/bucket/my_file.txt"}
)
)

client_manager = ClientManager()

obj = {"path": cog.types.URLFile("https://example.com/cdn/my_file.txt")}
result = await client_manager.upload_files(
obj, url="https://example.com/bucket", prediction_id=None
)
assert result == {"path": "https://cdn.example.com/bucket/my_file.txt"}

assert uploader.call_count == 1
assert urlopen_mock.call_count == 1
assert urlopen_mock.call_args[0][0] == "https://example.com/cdn/my_file.txt"


@pytest.mark.asyncio
@pytest.mark.respx(base_url="https://example.com")
@mock.patch("urllib.request.urlopen")
async def test_upload_files_with_url_file_with_retry(urlopen_mock, respx_mock):
fp = io.BytesIO(b"hello world")
urlopen_mock.return_value = addinfourl(
fp=fp, headers=Message(), url="https://example.com/cdn/my_file.txt"
)

uploader = respx_mock.put("/bucket/my_file.txt").mock(
return_value=httpx.Response(502)
)

client_manager = ClientManager()

obj = {"path": cog.types.URLFile("https://example.com/cdn/my_file.txt")}
with pytest.raises(httpx.HTTPStatusError):
await client_manager.upload_files(
obj, url="https://example.com/bucket", prediction_id=None
)

assert result == {"path": "https://cdn.example.com/bucket/my_file.txt"}
assert uploader.call_count == 3
assert uploader.call_count == 3
assert urlopen_mock.call_count == 1
assert urlopen_mock.call_args[0][0] == "https://example.com/cdn/my_file.txt"
48 changes: 23 additions & 25 deletions python/tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,40 @@
import io
import pickle
import urllib.request
from urllib.response import addinfourl
from email.message import Message
from unittest import mock

import pytest
import responses
from cog.types import Secret, URLFile, get_filename_from_url, get_filename_from_urlopen


@responses.activate
def test_urlfile_acts_like_response():
responses.get(
"https://example.com/some/url",
json={"message": "hello world"},
status=200,
# Represents a response from urllib.request.urlopen
def file_fixture(body: str):
return addinfourl(
fp=io.BytesIO(bytes(body, "utf-8")),
headers=Message(),
url="https://example.com/cdn/my_file.txt",
)


@mock.patch("urllib.request.urlopen", return_value=file_fixture("hello world"))
def test_urlfile_acts_like_response(mock_urlopen):
u = URLFile("https://example.com/some/url")

assert isinstance(u, io.IOBase)
assert u.read() == b'{"message": "hello world"}'

assert u.read() == b"hello world"
assert mock_urlopen.call_count == 1

@responses.activate
def test_urlfile_iterable():
responses.get(
"https://example.com/some/url",
body="one\ntwo\nthree\n",
status=200,
)

@mock.patch("urllib.request.urlopen", return_value=file_fixture("one\ntwo\nthree\n"))
def test_urlfile_iterable(mock_urlopen):
u = URLFile("https://example.com/some/url")
result = list(u)

assert result == [b"one\n", b"two\n", b"three\n"]
assert mock_urlopen.call_count == 1


@responses.activate
Expand All @@ -42,29 +44,25 @@ def test_urlfile_no_request_if_not_used():
URLFile("https://example.com/some/url")


@responses.activate
def test_urlfile_can_be_pickled():
@mock.patch("urllib.request.urlopen", return_value=file_fixture("hello world"))
def test_urlfile_can_be_pickled(mock_urlopen):
u = URLFile("https://example.com/some/url")

result = pickle.loads(pickle.dumps(u))

assert isinstance(result, URLFile)
assert mock_urlopen.call_count == 0


@responses.activate
def test_urlfile_can_be_pickled_even_once_loaded():
responses.get(
"https://example.com/some/url",
json={"message": "hello world"},
status=200,
)

@mock.patch("urllib.request.urlopen", return_value=file_fixture("hello world"))
def test_urlfile_can_be_pickled_even_once_loaded(mock_urlopen):
u = URLFile("https://example.com/some/url")
u.read()

result = pickle.loads(pickle.dumps(u))

assert isinstance(result, URLFile)
assert mock_urlopen.call_count == 1


@pytest.mark.parametrize(
Expand Down
Loading