Skip to content

Add stream method on Prediction #215

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@

from typing_extensions import NotRequired, TypedDict, Unpack

from replicate.exceptions import ModelError
from replicate.exceptions import ModelError, ReplicateError
from replicate.files import upload_file
from replicate.json import encode_json
from replicate.pagination import Page
from replicate.resource import Namespace, Resource
from replicate.stream import EventSource
from replicate.version import Version

try:
Expand All @@ -19,6 +20,7 @@

if TYPE_CHECKING:
from replicate.client import Client
from replicate.stream import ServerSentEvent


class Prediction(Resource):
Expand Down Expand Up @@ -125,6 +127,25 @@ def wait(self) -> None:
time.sleep(self._client.poll_interval)
self.reload()

def stream(self) -> Iterator["ServerSentEvent"]:
    """
    Stream the prediction output.

    Returns:
        An iterator of server-sent events read from the prediction's
        stream URL.

    Raises:
        ReplicateError: If the model does not support streaming.
    """

    # Validate eagerly. If this whole method were a generator function,
    # the `raise` below would be deferred until the caller's first
    # iteration, silently contradicting the documented contract.
    url = self.urls and self.urls.get("stream", None)
    if not url or not isinstance(url, str):
        raise ReplicateError("Model does not support streaming")

    headers = {
        "Accept": "text/event-stream",
        # SSE responses are incremental; they must never be served from cache.
        "Cache-Control": "no-store",
    }

    def _events() -> Iterator["ServerSentEvent"]:
        # The streaming response is only open while events are consumed;
        # the context manager closes the connection when iteration ends.
        with self._client._client.stream("GET", url, headers=headers) as response:
            yield from EventSource(response)

    return _events()

def cancel(self) -> None:
"""
Cancels a running prediction.
Expand Down
21 changes: 21 additions & 0 deletions tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,24 @@ async def test_stream(async_flag, record_mode):

assert len(events) > 0
assert events[0].event == "output"


@pytest.mark.asyncio
async def test_stream_prediction(record_mode):
    # This test replays recorded HTTP cassettes; nothing to do when
    # recording is disabled entirely.
    if record_mode == "none":
        return

    version = "02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"

    prediction = replicate.predictions.create(
        version=version,
        input={
            "prompt": "Please write a haiku about llamas.",
        },
    )

    # Drain the prediction's SSE stream and check that output arrived first.
    events = list(prediction.stream())

    assert len(events) > 0
    assert events[0].event == "output"