Skip to content

Commit 6996db3

Browse files
committed
dubious upload fix
Signed-off-by: technillogue <technillogue@gmail.com>
1 parent 7b9179b commit 6996db3

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

python/cog/server/clients.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ async def upload_file(self, fh: io.IOBase, url: Optional[str]) -> str:
156156
# in that case we need to return data uris
157157
if url is None:
158158
return file_to_data_uri(fh, content_type)
159+
assert url
159160

160161
# ensure trailing slash
161162
url_with_trailing_slash = url if url.endswith("/") else url + "/"
@@ -169,10 +170,20 @@ async def chunk_file_reader() -> AsyncIterator[bytes]:
169170
break
170171
yield chunk
171172

173+
url = url_with_trailing_slash + filename
174+
if url and "internal" in url:
175+
resp1 = await self.file_client.put(
176+
url,
177+
content=b"",
178+
headers={"Content-Type": content_type},
179+
follow_redirects=False,
180+
)
181+
if resp1.status_code == 307:
182+
url = resp1.headers["Location"]
172183
resp = await self.file_client.put(
173-
url_with_trailing_slash + filename,
184+
url,
174185
content=chunk_file_reader(),
175-
headers={"Content-type": content_type},
186+
headers={"Content-Type": content_type},
176187
)
177188
resp.raise_for_status()
178189

python/cog/server/runner.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import typing # TypeAlias, py3.10
1111
from datetime import datetime, timezone
1212
from enum import Enum, auto, unique
13-
from typing import Any, AsyncIterator, Iterator, Optional, Union, TypeVar
13+
from typing import Any, AsyncIterator, Awaitable, Iterator, Optional, Union
1414

1515
import httpx
1616
import structlog
@@ -483,13 +483,16 @@ async def handle_event_stream(
483483
break
484484
return self.response
485485

486+
async def noop(self) -> None:
487+
pass
488+
486489
def event_to_handle_future(self, event: PublicEventType) -> Awaitable[None]:
487490
if isinstance(event, Heartbeat):
488491
# Heartbeat events exist solely to ensure that we have a
489492
# regular opportunity to check for cancelation and
490493
# timeouts.
491494
# We don't need to do anything with them.
492-
return
495+
return self.noop()
493496
if isinstance(event, Log):
494497
return self.append_logs(event.message)
495498

@@ -500,9 +503,9 @@ def event_to_handle_future(self, event: PublicEventType) -> Awaitable[None]:
500503
if self._output_type.multi:
501504
return self.set_output([])
502505
if isinstance(event, PredictionOutput):
503-
if output_type is None:
506+
if self._output_type is None:
504507
return self.failed(error="Predictor returned unexpected output")
505-
if output_type.multi:
508+
if self._output_type.multi:
506509
return self.append_output(event.payload)
507510
return self.set_output(event.payload)
508511
if isinstance(event, Done): # pyright: ignore reportUnnecessaryIsinstance

0 commit comments

Comments
 (0)