Skip to content

Include prediction id upload request #1667

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions python/cog/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import io
import mimetypes
import os
from typing import Optional
from urllib.parse import urlparse

import requests
Expand Down Expand Up @@ -39,7 +40,7 @@ def guess_filename(obj: io.IOBase) -> str:


def put_file_to_signed_endpoint(
fh: io.IOBase, endpoint: str, client: requests.Session
fh: io.IOBase, endpoint: str, client: requests.Session, prediction_id: Optional[str]
) -> str:
fh.seek(0)

Expand All @@ -51,18 +52,28 @@ def put_file_to_signed_endpoint(
connect_timeout = 10
read_timeout = 15

headers = {
"Content-Type": content_type,
}
if prediction_id is not None:
headers["X-Prediction-ID"] = prediction_id

resp = client.put(
ensure_trailing_slash(endpoint) + filename,
fh, # type: ignore
headers={"Content-type": content_type},
headers=headers,
timeout=(connect_timeout, read_timeout),
)
resp.raise_for_status()

# strip any signing gubbins from the URL
final_url = urlparse(resp.url)._replace(query="").geturl()
# Try to extract the final asset URL from the `Location` header
# otherwise fallback to the URL of the final request.
final_url = resp.url
if "location" in resp.headers:
final_url = resp.headers.get("location")

return final_url
# strip any signing gubbins from the URL
return str(urlparse(final_url)._replace(query="").geturl())


def ensure_trailing_slash(url: str) -> str:
Expand Down
10 changes: 7 additions & 3 deletions python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def create_event_handler(

file_uploader = None
if upload_url is not None:
file_uploader = generate_file_uploader(upload_url)
file_uploader = generate_file_uploader(upload_url, prediction_id=prediction.id)

event_handler = PredictionEventHandler(
response, webhook_sender=webhook_sender, file_uploader=file_uploader
Expand All @@ -202,12 +202,16 @@ def create_event_handler(
return event_handler


def generate_file_uploader(upload_url: str) -> Callable[[Any], Any]:
def generate_file_uploader(
upload_url: str, prediction_id: Optional[str]
) -> Callable[[Any], Any]:
client = _make_file_upload_http_client()

def file_uploader(output: Any) -> Any:
def upload_file(fh: io.IOBase) -> str:
return put_file_to_signed_endpoint(fh, upload_url, client=client)
return put_file_to_signed_endpoint(
fh, endpoint=upload_url, prediction_id=prediction_id, client=client
)

return upload_files(output, upload_file=upload_file)

Expand Down
93 changes: 93 additions & 0 deletions python/tests/cog/test_files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import requests
import io
import responses
from cog.files import put_file_to_signed_endpoint
from unittest.mock import Mock


def test_put_file_to_signed_endpoint():
mock_fh = io.BytesIO()
mock_client = Mock()

mock_response = Mock(spec=requests.Response)
mock_response.status_code = 201
mock_response.text = ""
mock_response.headers = {}
mock_response.url = "http://example.com/upload/file?some-gubbins"
mock_response.ok = True

mock_client.put.return_value = mock_response

final_url = put_file_to_signed_endpoint(
mock_fh, "http://example.com/upload", mock_client, prediction_id=None
)

assert final_url == "http://example.com/upload/file"
mock_client.put.assert_called_with(
"http://example.com/upload/file",
mock_fh,
headers={
"Content-Type": None,
},
timeout=(10, 15),
)


def test_put_file_to_signed_endpoint_with_prediction_id():
mock_fh = io.BytesIO()
mock_client = Mock()

mock_response = Mock(spec=requests.Response)
mock_response.status_code = 201
mock_response.text = ""
mock_response.headers = {}
mock_response.url = "http://example.com/upload/file?some-gubbins"
mock_response.ok = True

mock_client.put.return_value = mock_response

final_url = put_file_to_signed_endpoint(
mock_fh, "http://example.com/upload", mock_client, prediction_id="abc123"
)

assert final_url == "http://example.com/upload/file"
mock_client.put.assert_called_with(
"http://example.com/upload/file",
mock_fh,
headers={
"Content-Type": None,
"X-Prediction-ID": "abc123",
},
timeout=(10, 15),
)


def test_put_file_to_signed_endpoint_with_location():
mock_fh = io.BytesIO()
mock_client = Mock()

mock_response = Mock(spec=requests.Response)
mock_response.status_code = 201
mock_response.text = ""
mock_response.headers = {
"location": "http://cdn.example.com/bucket/file?some-gubbins"
}
mock_response.url = "http://example.com/upload/file?some-gubbins"
mock_response.ok = True

mock_client.put.return_value = mock_response

final_url = put_file_to_signed_endpoint(
mock_fh, "http://example.com/upload", mock_client, prediction_id="abc123"
)

assert final_url == "http://cdn.example.com/bucket/file"
mock_client.put.assert_called_with(
"http://example.com/upload/file",
mock_fh,
headers={
"Content-Type": None,
"X-Prediction-ID": "abc123",
},
timeout=(10, 15),
)