Skip to content

Commit 18779eb

Browse files
authored
Add stream method on Prediction (#215)
Follow-up to #204. This PR adds a `stream` method on `Prediction`. In most cases, you'll want to call `replicate.stream` to get the stream directly. But if you have an existing prediction object and want to stream its output, you can use this method instead.[^1]

[^1]: This rhymes with `run` being a shorthand for creating a prediction and calling `wait`.

Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent 6ad79b4 commit 18779eb

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

replicate/prediction.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55

66
from typing_extensions import NotRequired, TypedDict, Unpack
77

8-
from replicate.exceptions import ModelError
8+
from replicate.exceptions import ModelError, ReplicateError
99
from replicate.files import upload_file
1010
from replicate.json import encode_json
1111
from replicate.pagination import Page
1212
from replicate.resource import Namespace, Resource
13+
from replicate.stream import EventSource
1314
from replicate.version import Version
1415

1516
try:
@@ -19,6 +20,7 @@
1920

2021
if TYPE_CHECKING:
2122
from replicate.client import Client
23+
from replicate.stream import ServerSentEvent
2224

2325

2426
class Prediction(Resource):
@@ -125,6 +127,25 @@ def wait(self) -> None:
125127
time.sleep(self._client.poll_interval)
126128
self.reload()
127129

130+
def stream(self) -> Iterator["ServerSentEvent"]:
    """
    Stream the prediction output.

    In most cases you'll want ``replicate.stream`` to create a prediction
    and stream it in one call; use this method when you already hold a
    ``Prediction`` object and want to stream its output.

    Yields:
        ServerSentEvent: Events parsed from the prediction's stream endpoint.

    Raises:
        ReplicateError: If the model does not support streaming (i.e. the
            prediction has no usable ``stream`` URL).
    """
    # NOTE: a generator function never returns None — it either yields or
    # raises — so the return annotation is Iterator, not Optional[Iterator].

    # The "stream" URL is only present when the model supports streaming.
    url = self.urls and self.urls.get("stream", None)
    if not url or not isinstance(url, str):
        raise ReplicateError("Model does not support streaming")

    headers = {
        "Accept": "text/event-stream",  # request SSE framing
        "Cache-Control": "no-store",  # never serve stale events from a cache
    }

    with self._client._client.stream("GET", url, headers=headers) as response:
        yield from EventSource(response)
128149
def cancel(self) -> None:
129150
"""
130151
Cancels a running prediction.

tests/test_stream.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,24 @@ async def test_stream(async_flag, record_mode):
3232

3333
assert len(events) > 0
3434
assert events[0].event == "output"
35+
36+
37+
@pytest.mark.asyncio
async def test_stream_prediction(record_mode):
    # Without recorded cassettes there is nothing to replay — bail out early.
    if record_mode == "none":
        return

    version = "02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"

    model_input = {
        "prompt": "Please write a haiku about llamas.",
    }

    prediction = replicate.predictions.create(version=version, input=model_input)

    # Drain the stream and verify it produced at least one output event.
    events = list(prediction.stream())

    assert len(events) > 0
    assert events[0].event == "output"

0 commit comments

Comments (0)