Skip to content

Commit fed75e2

Browse files
committed
Add support for models.create endpoint
Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent c41c8c5 commit fed75e2

File tree

4 files changed

+168
-2
lines changed

4 files changed

+168
-2
lines changed

replicate/model.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,63 @@ def get(self, key: str) -> Model:
150150
resp = self._client._request("GET", f"/v1/models/{key}")
151151
return self._prepare_model(resp.json())
152152

153+
def create( # pylint: disable=arguments-differ disable=too-many-arguments
154+
self,
155+
owner: str,
156+
name: str,
157+
*,
158+
visibility: str,
159+
hardware: str,
160+
description: Optional[str] = None,
161+
github_url: Optional[str] = None,
162+
paper_url: Optional[str] = None,
163+
license_url: Optional[str] = None,
164+
cover_image_url: Optional[str] = None,
165+
) -> Model:
166+
"""
167+
Create a model.
168+
169+
Args:
170+
owner: The name of the user or organization that will own the model.
171+
name: The name of the model.
172+
visibility: Whether the model should be public or private.
173+
hardware: The SKU for the hardware used to run the model. Possible values can be found by calling `replicate.hardware.list()`.
174+
description: A description of the model.
175+
github_url: A URL for the model's source code on GitHub.
176+
paper_url: A URL for the model's paper.
177+
license_url: A URL for the model's license.
178+
cover_image_url: A URL for the model's cover image.
179+
180+
Returns:
181+
The created model.
182+
"""
183+
184+
body = {
185+
"owner": owner,
186+
"name": name,
187+
"visibility": visibility,
188+
"hardware": hardware,
189+
}
190+
191+
if description is not None:
192+
body["description"] = description
193+
194+
if github_url is not None:
195+
body["github_url"] = github_url
196+
197+
if paper_url is not None:
198+
body["paper_url"] = paper_url
199+
200+
if license_url is not None:
201+
body["license_url"] = license_url
202+
203+
if cover_image_url is not None:
204+
body["cover_image_url"] = cover_image_url
205+
206+
resp = self._client._request("POST", "/v1/models", json=body)
207+
208+
return self._prepare_model(resp.json())
209+
153210
def _prepare_model(self, attrs: Union[Model, Dict]) -> Model:
154211
if isinstance(attrs, BaseModel):
155212
attrs.id = f"{attrs.owner}/{attrs.name}"

tests/cassettes/models-create.yaml

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
interactions:
2+
- request:
3+
body: '{"owner": "test", "name": "python-example", "visibility": "private", "hardware":
4+
"cpu", "description": "An example model"}'
5+
headers:
6+
accept:
7+
- '*/*'
8+
accept-encoding:
9+
- gzip, deflate
10+
connection:
11+
- keep-alive
12+
content-length:
13+
- '123'
14+
content-type:
15+
- application/json
16+
host:
17+
- api.replicate.com
18+
user-agent:
19+
- replicate-python/0.15.6
20+
method: POST
21+
uri: https://api.replicate.com/v1/models
22+
response:
23+
content: '{"url": "https://replicate.com/test/python-example", "owner": "test",
24+
"name": "python-example", "description": "An example model", "visibility": "private",
25+
"github_url": null, "paper_url": null, "license_url": null, "run_count": 0,
26+
"cover_image_url": null, "default_example": null, "latest_version": null}'
27+
headers:
28+
CF-Cache-Status:
29+
- DYNAMIC
30+
CF-RAY:
31+
- 81ff2e098ec0eb5b-SEA
32+
Connection:
33+
- keep-alive
34+
Content-Length:
35+
- '307'
36+
Content-Type:
37+
- application/json
38+
Date:
39+
- Thu, 02 Nov 2023 20:38:12 GMT
40+
Server:
41+
- cloudflare
42+
Strict-Transport-Security:
43+
- max-age=15552000
44+
allow:
45+
- GET, POST, HEAD, OPTIONS
46+
content-security-policy-report-only:
47+
- 'font-src ''report-sample'' ''self'' data:; img-src ''report-sample'' ''self''
48+
data: https://replicate.delivery https://*.replicate.delivery https://*.githubusercontent.com
49+
https://github.com; script-src ''report-sample'' ''self'' https://cdn.rudderlabs.com/v1.1/rudder-analytics.min.js;
50+
style-src ''report-sample'' ''self'' ''unsafe-inline''; connect-src ''report-sample''
51+
''self'' https://replicate.delivery https://*.replicate.delivery https://*.rudderlabs.com
52+
https://*.rudderstack.com https://*.mux.com https://*.sentry.io; worker-src
53+
''none''; media-src ''report-sample'' ''self'' https://replicate.delivery
54+
https://*.replicate.delivery https://*.mux.com https://*.sentry.io; default-src
55+
''self''; report-uri'
56+
cross-origin-opener-policy:
57+
- same-origin
58+
nel:
59+
- '{"report_to":"heroku-nel","max_age":3600,"success_fraction":0.005,"failure_fraction":0.05,"response_headers":["Via"]}'
60+
ratelimit-remaining:
61+
- '2999'
62+
ratelimit-reset:
63+
- '1'
64+
referrer-policy:
65+
- same-origin
66+
report-to:
67+
- '{"group":"heroku-nel","max_age":3600,"endpoints":[{"url":"https://nel.heroku.com/reports?ts=1698957492&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=m%2Fs583uNWdN4J4bm1G3JZoilUVMbh89egg%2FAEcTPZm4%3D"}]}'
68+
reporting-endpoints:
69+
- heroku-nel=https://nel.heroku.com/reports?ts=1698957492&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=m%2Fs583uNWdN4J4bm1G3JZoilUVMbh89egg%2FAEcTPZm4%3D
70+
vary:
71+
- Cookie, origin
72+
via:
73+
- 1.1 vegur, 1.1 google
74+
x-content-type-options:
75+
- nosniff
76+
x-frame-options:
77+
- DENY
78+
http_version: HTTP/1.1
79+
status_code: 201
80+
version: 1

tests/test_hardware.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import httpx
21
import pytest
3-
import respx
42

53
import replicate
64

tests/test_model.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,34 @@ async def test_models_list(mock_replicate_api_token):
2727
assert models[0].owner is not None
2828
assert models[0].name is not None
2929
assert models[0].visibility == "public"
30+
31+
32+
@pytest.mark.vcr("models-create.yaml")
33+
@pytest.mark.asyncio
34+
async def test_models_create(mock_replicate_api_token):
35+
model = replicate.models.create(
36+
owner="test",
37+
name="python-example",
38+
visibility="private",
39+
hardware="cpu",
40+
description="An example model",
41+
)
42+
43+
assert model.owner == "test"
44+
assert model.name == "python-example"
45+
assert model.visibility == "private"
46+
47+
48+
@pytest.mark.vcr("models-create.yaml")
49+
@pytest.mark.asyncio
50+
async def test_models_create_with_positional_arguments(mock_replicate_api_token):
51+
model = replicate.models.create(
52+
"test",
53+
"python-example",
54+
visibility="private",
55+
hardware="cpu",
56+
)
57+
58+
assert model.owner == "test"
59+
assert model.name == "python-example"
60+
assert model.visibility == "private"

0 commit comments

Comments
 (0)