Skip to content

Commit 013d180

Browse files
authored
Add stream parameter to predictions.create (#131)
Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent 5fc83da commit 013d180

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

replicate/prediction.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ def create( # type: ignore
129129
webhook: Optional[str] = None,
130130
webhook_completed: Optional[str] = None,
131131
webhook_events_filter: Optional[List[str]] = None,
132+
*,
133+
stream: Optional[bool] = None,
132134
**kwargs,
133135
) -> Prediction:
134136
"""
@@ -157,6 +159,8 @@ def create( # type: ignore
157159
body["webhook_completed"] = webhook_completed
158160
if webhook_events_filter is not None:
159161
body["webhook_events_filter"] = webhook_events_filter
162+
if stream is True:
163+
body["stream"] = "true"
160164

161165
resp = self._client._request(
162166
"POST",

tests/test_prediction.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,55 @@ def test_cancel():
9494
assert rsp.call_count == 1
9595

9696

97+
@responses.activate
98+
def test_stream():
99+
client = create_client()
100+
version = create_version(client)
101+
102+
rsp = responses.post(
103+
"https://api.replicate.com/v1/predictions",
104+
match=[
105+
matchers.json_params_matcher(
106+
{
107+
"version": "v1",
108+
"input": {"text": "world"},
109+
"stream": "true",
110+
}
111+
),
112+
],
113+
json={
114+
"id": "p1",
115+
"version": "v1",
116+
"urls": {
117+
"get": "https://api.replicate.com/v1/predictions/p1",
118+
"cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
119+
"stream": "https://streaming.api.replicate.com/v1/predictions/p1",
120+
},
121+
"created_at": "2022-04-26T20:00:40.658234Z",
122+
"completed_at": "2022-04-26T20:02:27.648305Z",
123+
"source": "api",
124+
"status": "processing",
125+
"input": {"text": "world"},
126+
"output": None,
127+
"error": None,
128+
"logs": "",
129+
},
130+
)
131+
132+
prediction = client.predictions.create(
133+
version=version,
134+
input={"text": "world"},
135+
stream=True,
136+
)
137+
138+
assert rsp.call_count == 1
139+
140+
assert (
141+
prediction.urls["stream"]
142+
== "https://streaming.api.replicate.com/v1/predictions/p1"
143+
)
144+
145+
97146
@responses.activate
98147
def test_async_timings():
99148
client = create_client()

0 commit comments

Comments
 (0)