|
5 | 5 |
|
6 | 6 | from typing_extensions import NotRequired, TypedDict, Unpack
|
7 | 7 |
|
8 |
| -from replicate.exceptions import ModelError |
| 8 | +from replicate.exceptions import ModelError, ReplicateError |
9 | 9 | from replicate.files import upload_file
|
10 | 10 | from replicate.json import encode_json
|
11 | 11 | from replicate.pagination import Page
|
12 | 12 | from replicate.resource import Namespace, Resource
|
| 13 | +from replicate.stream import EventSource |
13 | 14 | from replicate.version import Version
|
14 | 15 |
|
15 | 16 | try:
|
|
19 | 20 |
|
20 | 21 | if TYPE_CHECKING:
|
21 | 22 | from replicate.client import Client
|
| 23 | + from replicate.stream import ServerSentEvent |
22 | 24 |
|
23 | 25 |
|
24 | 26 | class Prediction(Resource):
|
@@ -125,6 +127,25 @@ def wait(self) -> None:
|
125 | 127 | time.sleep(self._client.poll_interval)
|
126 | 128 | self.reload()
|
127 | 129 |
|
| 130 | + def stream(self) -> Optional[Iterator["ServerSentEvent"]]: |
| 131 | + """ |
| 132 | + Stream the prediction output. |
| 133 | +
|
| 134 | + Raises: |
| 135 | + ReplicateError: If the model does not support streaming. |
| 136 | + """ |
| 137 | + |
| 138 | + url = self.urls and self.urls.get("stream", None) |
| 139 | + if not url or not isinstance(url, str): |
| 140 | + raise ReplicateError("Model does not support streaming") |
| 141 | + |
| 142 | + headers = {} |
| 143 | + headers["Accept"] = "text/event-stream" |
| 144 | + headers["Cache-Control"] = "no-store" |
| 145 | + |
| 146 | + with self._client._client.stream("GET", url, headers=headers) as response: |
| 147 | + yield from EventSource(response) |
| 148 | + |
128 | 149 | def cancel(self) -> None:
|
129 | 150 | """
|
130 | 151 | Cancels a running prediction.
|
|
0 commit comments