Skip to content

Commit fd3074f

Browse files
committed
Add support for new webhooks
Signed-off-by: Ben Firshman <ben@firshman.com>
1 parent 3a1a933 commit fd3074f

File tree

2 files changed

+53
-1
lines changed

2 files changed

+53
-1
lines changed

replicate/prediction.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,21 @@ def create(
5959
self,
6060
version: Version,
6161
input: Dict[str, Any],
62+
webhook: Optional[str] = None,
6263
webhook_completed: Optional[str] = None,
64+
webhook_events_filter: Optional[List[str]] = None,
6365
) -> Prediction:
6466
input = encode_json(input, upload_file=upload_file)
6567
body = {
6668
"version": version.id,
6769
"input": input,
6870
}
71+
if webhook is not None:
72+
body["webhook"] = webhook
6973
if webhook_completed is not None:
7074
body["webhook_completed"] = webhook_completed
75+
if webhook_events_filter is not None:
76+
body["webhook_events_filter"] = webhook_events_filter
7177

7278
resp = self._client._request(
7379
"POST",

tests/test_prediction.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,56 @@
1-
import replicate
21
import responses
32
from responses import matchers
43

4+
import replicate
5+
56
from .factories import create_client, create_version
67

78

9+
@responses.activate
10+
def test_create_works_with_webhooks():
11+
client = create_client()
12+
version = create_version(client)
13+
14+
rsp = responses.post(
15+
"https://api.replicate.com/v1/predictions",
16+
match=[
17+
matchers.json_params_matcher(
18+
{
19+
"version": "v1",
20+
"input": {"text": "world"},
21+
"webhook": "https://example.com/webhook",
22+
"webhook_events_filter": ["completed"],
23+
}
24+
),
25+
],
26+
json={
27+
"id": "p1",
28+
"version": "v1",
29+
"urls": {
30+
"get": "https://api.replicate.com/v1/predictions/p1",
31+
"cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
32+
},
33+
"created_at": "2022-04-26T20:00:40.658234Z",
34+
"completed_at": "2022-04-26T20:02:27.648305Z",
35+
"source": "api",
36+
"status": "processing",
37+
"input": {"text": "world"},
38+
"output": None,
39+
"error": None,
40+
"logs": "",
41+
},
42+
)
43+
44+
prediction = client.predictions.create(
45+
version=version,
46+
input={"text": "world"},
47+
webhook="https://example.com/webhook",
48+
webhook_events_filter=["completed"],
49+
)
50+
51+
assert rsp.call_count == 1
52+
53+
854
@responses.activate
955
def test_cancel():
1056
client = create_client()

0 commit comments

Comments
 (0)