Skip to content

Commit cc7101f

Browse files
authored
Add support for trainings.list endpoint (#88)
Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent bf0eb7d commit cc7101f

File tree

2 files changed

+182
-1
lines changed

2 files changed

+182
-1
lines changed

replicate/training.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,13 @@ class TrainingCollection(Collection):
3030
model = Training
3131

3232
def list(self) -> List[Training]:
33-
raise NotImplementedError()
33+
resp = self._client._request("GET", "/v1/trainings")
34+
# TODO: paginate
35+
trainings = resp.json()["results"]
36+
for training in trainings:
37+
# HACK: resolve this? make it lazy somehow?
38+
del training["version"]
39+
return [self.prepare_model(obj) for obj in trainings]
3440

3541
def get(self, id: str) -> Training:
3642
resp = self._client._request(

tests/test_training.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
import responses
2+
from responses import matchers
3+
4+
from .factories import create_client, create_version
5+
6+
7+
@responses.activate
8+
def test_create_works_with_webhooks():
9+
client = create_client()
10+
version = create_version(client)
11+
12+
rsp = responses.post(
13+
"https://api.replicate.com/v1/models/owner/model/versions/v1/trainings",
14+
match=[
15+
matchers.json_params_matcher(
16+
{
17+
"input": {"data": "..."},
18+
"destination": "new_owner/new_model",
19+
"webhook": "https://example.com/webhook",
20+
"webhook_events_filter": ["completed"],
21+
}
22+
),
23+
],
24+
json={
25+
"id": "t1",
26+
"version": "v1",
27+
"urls": {
28+
"get": "https://api.replicate.com/v1/trainings/t1",
29+
"cancel": "https://api.replicate.com/v1/trainings/t1/cancel",
30+
},
31+
"created_at": "2022-04-26T20:00:40.658234Z",
32+
"completed_at": "2022-04-26T20:02:27.648305Z",
33+
"source": "api",
34+
"status": "processing",
35+
"input": {"data": "..."},
36+
"output": None,
37+
"error": None,
38+
"logs": "",
39+
},
40+
)
41+
42+
client.trainings.create(
43+
version=f"owner/model:{version.id}",
44+
input={"data": "..."},
45+
destination="new_owner/new_model",
46+
webhook="https://example.com/webhook",
47+
webhook_events_filter=["completed"],
48+
)
49+
50+
assert rsp.call_count == 1
51+
52+
53+
@responses.activate
54+
def test_cancel():
55+
client = create_client()
56+
version = create_version(client)
57+
58+
responses.post(
59+
"https://api.replicate.com/v1/models/owner/model/versions/v1/trainings",
60+
match=[
61+
matchers.json_params_matcher(
62+
{
63+
"input": {"data": "..."},
64+
"destination": "new_owner/new_model",
65+
"webhook": "https://example.com/webhook",
66+
"webhook_events_filter": ["completed"],
67+
}
68+
),
69+
],
70+
json={
71+
"id": "t1",
72+
"version": "v1",
73+
"urls": {
74+
"get": "https://api.replicate.com/v1/trainings/t1",
75+
"cancel": "https://api.replicate.com/v1/trainings/t1/cancel",
76+
},
77+
"created_at": "2022-04-26T20:00:40.658234Z",
78+
"completed_at": "2022-04-26T20:02:27.648305Z",
79+
"source": "api",
80+
"status": "processing",
81+
"input": {"data": "..."},
82+
"output": None,
83+
"error": None,
84+
"logs": "",
85+
},
86+
)
87+
88+
training = client.trainings.create(
89+
version=f"owner/model:{version.id}",
90+
input={"data": "..."},
91+
destination="new_owner/new_model",
92+
webhook="https://example.com/webhook",
93+
webhook_events_filter=["completed"],
94+
)
95+
96+
rsp = responses.post("https://api.replicate.com/v1/trainings/t1/cancel", json={})
97+
training.cancel()
98+
assert rsp.call_count == 1
99+
100+
101+
@responses.activate
102+
def test_async_timings():
103+
client = create_client()
104+
version = create_version(client)
105+
106+
responses.post(
107+
"https://api.replicate.com/v1/models/owner/model/versions/v1/trainings",
108+
match=[
109+
matchers.json_params_matcher(
110+
{
111+
"input": {"data": "..."},
112+
"destination": "new_owner/new_model",
113+
"webhook": "https://example.com/webhook",
114+
"webhook_events_filter": ["completed"],
115+
}
116+
),
117+
],
118+
json={
119+
"id": "t1",
120+
"version": "v1",
121+
"urls": {
122+
"get": "https://api.replicate.com/v1/trainings/t1",
123+
"cancel": "https://api.replicate.com/v1/trainings/t1/cancel",
124+
},
125+
"created_at": "2022-04-26T20:00:40.658234Z",
126+
"source": "api",
127+
"status": "processing",
128+
"input": {"data": "..."},
129+
"output": None,
130+
"error": None,
131+
"logs": "",
132+
},
133+
)
134+
135+
responses.get(
136+
"https://api.replicate.com/v1/trainings/t1",
137+
json={
138+
"id": "t1",
139+
"version": "v1",
140+
"urls": {
141+
"get": "https://api.replicate.com/v1/trainings/t1",
142+
"cancel": "https://api.replicate.com/v1/trainings/t1/cancel",
143+
},
144+
"created_at": "2022-04-26T20:00:40.658234Z",
145+
"completed_at": "2022-04-26T20:02:27.648305Z",
146+
"source": "api",
147+
"status": "succeeded",
148+
"input": {"data": "..."},
149+
"output": {
150+
"weights": "https://delivery.replicate.com/weights.tgz",
151+
"version": "v2",
152+
},
153+
"error": None,
154+
"logs": "",
155+
},
156+
)
157+
158+
training = client.trainings.create(
159+
version=f"owner/model:{version.id}",
160+
input={"data": "..."},
161+
destination="new_owner/new_model",
162+
webhook="https://example.com/webhook",
163+
webhook_events_filter=["completed"],
164+
)
165+
166+
assert training.created_at == "2022-04-26T20:00:40.658234Z"
167+
assert training.completed_at is None
168+
assert training.output is None
169+
170+
# trainings don't have a wait method, so simulate it by calling reload
171+
training.reload()
172+
assert training.created_at == "2022-04-26T20:00:40.658234Z"
173+
assert training.completed_at == "2022-04-26T20:02:27.648305Z"
174+
assert training.output["weights"] == "https://delivery.replicate.com/weights.tgz"
175+
assert training.output["version"] == "v2"

0 commit comments

Comments
 (0)