File tree Expand file tree Collapse file tree 2 files changed +53
-0
lines changed Expand file tree Collapse file tree 2 files changed +53
-0
lines changed Original file line number Diff line number Diff line change @@ -129,6 +129,8 @@ def create( # type: ignore
129
129
webhook : Optional [str ] = None ,
130
130
webhook_completed : Optional [str ] = None ,
131
131
webhook_events_filter : Optional [List [str ]] = None ,
132
+ * ,
133
+ stream : Optional [bool ] = None ,
132
134
** kwargs ,
133
135
) -> Prediction :
134
136
"""
@@ -157,6 +159,8 @@ def create( # type: ignore
157
159
body ["webhook_completed" ] = webhook_completed
158
160
if webhook_events_filter is not None :
159
161
body ["webhook_events_filter" ] = webhook_events_filter
162
+ if stream is True :
163
+ body ["stream" ] = "true"
160
164
161
165
resp = self ._client ._request (
162
166
"POST" ,
Original file line number Diff line number Diff line change @@ -94,6 +94,55 @@ def test_cancel():
94
94
assert rsp .call_count == 1
95
95
96
96
97
+ @responses .activate
98
+ def test_stream ():
99
+ client = create_client ()
100
+ version = create_version (client )
101
+
102
+ rsp = responses .post (
103
+ "https://api.replicate.com/v1/predictions" ,
104
+ match = [
105
+ matchers .json_params_matcher (
106
+ {
107
+ "version" : "v1" ,
108
+ "input" : {"text" : "world" },
109
+ "stream" : "true" ,
110
+ }
111
+ ),
112
+ ],
113
+ json = {
114
+ "id" : "p1" ,
115
+ "version" : "v1" ,
116
+ "urls" : {
117
+ "get" : "https://api.replicate.com/v1/predictions/p1" ,
118
+ "cancel" : "https://api.replicate.com/v1/predictions/p1/cancel" ,
119
+ "stream" : "https://streaming.api.replicate.com/v1/predictions/p1" ,
120
+ },
121
+ "created_at" : "2022-04-26T20:00:40.658234Z" ,
122
+ "completed_at" : "2022-04-26T20:02:27.648305Z" ,
123
+ "source" : "api" ,
124
+ "status" : "processing" ,
125
+ "input" : {"text" : "world" },
126
+ "output" : None ,
127
+ "error" : None ,
128
+ "logs" : "" ,
129
+ },
130
+ )
131
+
132
+ prediction = client .predictions .create (
133
+ version = version ,
134
+ input = {"text" : "world" },
135
+ stream = True ,
136
+ )
137
+
138
+ assert rsp .call_count == 1
139
+
140
+ assert (
141
+ prediction .urls ["stream" ]
142
+ == "https://streaming.api.replicate.com/v1/predictions/p1"
143
+ )
144
+
145
+
97
146
@responses .activate
98
147
def test_async_timings ():
99
148
client = create_client ()
You can’t perform that action at this time.
0 commit comments