Skip to content

Commit 3cc0b86

Browse files
committed
Implement a FileOutput interface
1 parent 57f6f0b commit 3cc0b86

File tree

1 file changed

+62
-3
lines changed

1 file changed

+62
-3
lines changed

replicate/stream.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import io
2+
import base64
3+
import httpx
14
from enum import Enum
25
from typing import (
36
TYPE_CHECKING,
@@ -9,9 +12,10 @@
912
Optional,
1013
Union,
1114
)
12-
15+
from contextlib import asynccontextmanager, contextmanager
1316
from typing_extensions import Unpack
1417

18+
1519
from replicate import identifier
1620
from replicate.exceptions import ReplicateError
1721

@@ -22,15 +26,70 @@
2226

2327

2428
if TYPE_CHECKING:
25-
import httpx
26-
2729
from replicate.client import Client
2830
from replicate.identifier import ModelVersionIdentifier
2931
from replicate.model import Model
3032
from replicate.prediction import Predictions
3133
from replicate.version import Version
3234

3335

36+
class FileOutputProvider:
37+
url: str
38+
client: "Client"
39+
40+
def __init__(self, url: str, client: "Client"):
41+
self.url = url
42+
self.client = client
43+
44+
def read(self) -> bytes:
45+
with self.stream() as file:
46+
return file.read()
47+
48+
@contextmanager
49+
def stream(self) -> Iterator["FileOutput"]:
50+
with self.client._client.stream("GET", self.url) as response:
51+
response.raise_for_status()
52+
yield FileOutput(response)
53+
54+
@asynccontextmanager
55+
async def astream(self) -> AsyncIterator["FileOutput"]:
56+
async with self.client._async_client.stream("GET", self.url) as response:
57+
response.raise_for_status()
58+
yield FileOutput(response)
59+
60+
async def aread(self) -> bytes:
61+
async with self.astream() as file:
62+
return await file.aread()
63+
64+
def __repr__(self) -> str:
65+
return self.url
66+
67+
68+
class FileOutput(httpx.ByteStream, httpx.AsyncByteStream):
69+
def __init__(self, response: httpx.Response):
70+
self.response = response
71+
72+
def __iter__(self) -> Iterator[bytes]:
73+
for bytes in self.response.iter_bytes():
74+
yield bytes
75+
76+
def close(self):
77+
return self.response.close()
78+
79+
def read(self):
80+
return self.response.read()
81+
82+
async def __aiter__(self) -> AsyncIterator[bytes]:
83+
async for bytes in self.response.aiter_bytes():
84+
yield bytes
85+
86+
async def aclose(self):
87+
return await self.response.aclose()
88+
89+
async def aread(self):
90+
return await self.response.aread()
91+
92+
3493
class ServerSentEvent(pydantic.BaseModel): # type: ignore
3594
"""
3695
A server-sent event.

0 commit comments

Comments
 (0)