Skip to content

Commit 70f2eb0

Browse files
authored
Add support for models.create and hardware.list endpoints (#184)
See https://replicate.com/docs/reference/http#models.create --------- Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent 988ec27 commit 70f2eb0

File tree

9 files changed

+345
-8
lines changed

9 files changed

+345
-8
lines changed

README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,29 @@ urlretrieve(out[0], "/tmp/out.png")
178178
background = Image.open("/tmp/out.png")
179179
```
180180

181+
## Create a model
182+
183+
You can create a model for a user or organization
184+
with a given name, visibility, and hardware SKU:
185+
186+
```python
187+
import replicate
188+
189+
model = replicate.models.create(
190+
owner="your-username",
191+
name="my-model",
192+
visibility="public",
193+
hardware="gpu-a40-large"
194+
)
195+
```
196+
197+
Here's how to list all of the available hardware for running models on Replicate:
198+
199+
```python
200+
>>> [hw.sku for hw in replicate.hardware.list()]
201+
['cpu', 'gpu-t4', 'gpu-a40-small', 'gpu-a40-large']
202+
```
203+
181204
## Development
182205

183206
See [CONTRIBUTING.md](CONTRIBUTING.md)

replicate/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
default_client = Client()
44
run = default_client.run
5+
hardware = default_client.hardware
6+
deployments = default_client.deployments
57
models = default_client.models
68
predictions = default_client.predictions
79
trainings = default_client.trainings
8-
deployments = default_client.deployments

replicate/client.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from replicate.__about__ import __version__
1818
from replicate.deployment import DeploymentCollection
1919
from replicate.exceptions import ModelError, ReplicateError
20+
from replicate.hardware import HardwareCollection
2021
from replicate.model import ModelCollection
2122
from replicate.prediction import PredictionCollection
2223
from replicate.schema import make_schema_backwards_compatible
@@ -83,6 +84,20 @@ def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
8384

8485
return resp
8586

87+
@property
def deployments(self) -> DeploymentCollection:
    """
    Entry point for deployment-related operations, bound to this client.

    A new collection object is constructed on every access.
    """
    collection = DeploymentCollection(client=self)
    return collection
93+
94+
@property
def hardware(self) -> HardwareCollection:
    """
    Entry point for hardware-related operations, bound to this client.

    A new collection object is constructed on every access.
    """
    collection = HardwareCollection(client=self)
    return collection
100+
86101
@property
87102
def models(self) -> ModelCollection:
88103
"""
@@ -104,13 +119,6 @@ def trainings(self) -> TrainingCollection:
104119
"""
105120
return TrainingCollection(client=self)
106121

107-
@property
108-
def deployments(self) -> DeploymentCollection:
109-
"""
110-
Namespace for operations related to deployments.
111-
"""
112-
return DeploymentCollection(client=self)
113-
114122
def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]: # noqa: ANN401
115123
"""
116124
Run a model and wait for its output.

replicate/hardware.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from typing import Dict, List, Union
2+
3+
from replicate.base_model import BaseModel
4+
from replicate.collection import Collection
5+
6+
7+
class Hardware(BaseModel):
    """
    A kind of hardware that a model can run on at Replicate.
    """

    sku: str
    """
    The hardware SKU, for example `cpu` or `gpu-a40-large`.
    """

    name: str
    """
    The display name of the hardware, for example `Nvidia T4 GPU`.
    """
21+
22+
23+
class HardwareCollection(Collection):
    """
    Namespace for operations related to hardware.
    """

    model = Hardware

    def list(self) -> List[Hardware]:
        """
        List all hardware available for running models on Replicate.

        Returns:
            A list of hardware.
        """

        resp = self._client._request("GET", "/v1/hardware")
        return [self._prepare_model(obj) for obj in resp.json()]

    def _prepare_model(self, attrs: Union[Hardware, Dict]) -> Hardware:
        """
        Set `id` from the hardware SKU, then delegate to the base collection.

        Args:
            attrs: Either a model instance or a dict of attributes that
                includes a `sku` entry. The argument is updated in place.

        Returns:
            Whatever the base collection's `_prepare_model` returns for `attrs`.
        """
        if isinstance(attrs, BaseModel):
            attrs.id = attrs.sku
        elif isinstance(attrs, dict):
            attrs["id"] = attrs["sku"]

        return super()._prepare_model(attrs)

replicate/model.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,63 @@ def get(self, key: str) -> Model:
150150
resp = self._client._request("GET", f"/v1/models/{key}")
151151
return self._prepare_model(resp.json())
152152

153+
def create(  # pylint: disable=arguments-differ disable=too-many-arguments
    self,
    owner: str,
    name: str,
    *,
    visibility: str,
    hardware: str,
    description: Optional[str] = None,
    github_url: Optional[str] = None,
    paper_url: Optional[str] = None,
    license_url: Optional[str] = None,
    cover_image_url: Optional[str] = None,
) -> Model:
    """
    Create a new model for a user or organization.

    Optional metadata is sent only when a value other than None is given.

    Args:
        owner: The name of the user or organization that will own the model.
        name: The name of the model.
        visibility: Whether the model should be public or private.
        hardware: The SKU for the hardware used to run the model. Possible values can be found by calling `replicate.hardware.list()`.
        description: A description of the model.
        github_url: A URL for the model's source code on GitHub.
        paper_url: A URL for the model's paper.
        license_url: A URL for the model's license.
        cover_image_url: A URL for the model's cover image.

    Returns:
        The created model.
    """

    # Listed in the order the keys are added to the request body.
    optional_fields = {
        "description": description,
        "github_url": github_url,
        "paper_url": paper_url,
        "license_url": license_url,
        "cover_image_url": cover_image_url,
    }

    payload = {
        "owner": owner,
        "name": name,
        "visibility": visibility,
        "hardware": hardware,
        **{
            key: value
            for key, value in optional_fields.items()
            if value is not None
        },
    }

    response = self._client._request("POST", "/v1/models", json=payload)
    return self._prepare_model(response.json())
209+
153210
def _prepare_model(self, attrs: Union[Model, Dict]) -> Model:
154211
if isinstance(attrs, BaseModel):
155212
attrs.id = f"{attrs.owner}/{attrs.name}"

tests/cassettes/hardware-list.yaml

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
interactions:
2+
- request:
3+
body: ''
4+
headers:
5+
accept:
6+
- '*/*'
7+
accept-encoding:
8+
- gzip, deflate
9+
connection:
10+
- keep-alive
11+
host:
12+
- api.replicate.com
13+
user-agent:
14+
- replicate-python/0.15.5
15+
method: GET
16+
uri: https://api.replicate.com/v1/hardware
17+
response:
18+
content: '[{"sku":"cpu","name":"CPU"},{"sku":"gpu-t4","name":"Nvidia T4 GPU"},{"sku":"gpu-a40-small","name":"Nvidia
19+
A40 GPU"},{"sku":"gpu-a40-large","name":"Nvidia A40 (Large) GPU"}]'
20+
headers:
21+
CF-Cache-Status:
22+
- DYNAMIC
23+
CF-RAY:
24+
- 81fbfed29fe1c58a-SEA
25+
Connection:
26+
- keep-alive
27+
Content-Encoding:
28+
- gzip
29+
Content-Type:
30+
- application/json
31+
Date:
32+
- Thu, 02 Nov 2023 11:21:41 GMT
33+
Server:
34+
- cloudflare
35+
Strict-Transport-Security:
36+
- max-age=15552000
37+
Transfer-Encoding:
38+
- chunked
39+
allow:
40+
- OPTIONS, GET
41+
content-security-policy-report-only:
42+
- 'connect-src ''report-sample'' ''self'' https://replicate.delivery https://*.replicate.delivery
43+
https://*.rudderlabs.com https://*.rudderstack.com https://*.mux.com https://*.sentry.io;
44+
worker-src ''none''; script-src ''report-sample'' ''self'' https://cdn.rudderlabs.com/v1.1/rudder-analytics.min.js;
45+
style-src ''report-sample'' ''self'' ''unsafe-inline''; font-src ''report-sample''
46+
''self'' data:; img-src ''report-sample'' ''self'' data: https://replicate.delivery
47+
https://*.replicate.delivery https://*.githubusercontent.com https://github.com;
48+
default-src ''self''; media-src ''report-sample'' ''self'' https://replicate.delivery
49+
https://*.replicate.delivery https://*.mux.com https://*.sentry.io; report-uri'
50+
cross-origin-opener-policy:
51+
- same-origin
52+
nel:
53+
- '{"report_to":"heroku-nel","max_age":3600,"success_fraction":0.005,"failure_fraction":0.05,"response_headers":["Via"]}'
54+
ratelimit-remaining:
55+
- '2999'
56+
ratelimit-reset:
57+
- '1'
58+
referrer-policy:
59+
- same-origin
60+
report-to:
61+
- '{"group":"heroku-nel","max_age":3600,"endpoints":[{"url":"https://nel.heroku.com/reports?ts=1698924101&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=lMEXEYwO4dAJOZgt0b6ihblK5I4BwDBadrW6odcdYW8%3D"}]}'
62+
reporting-endpoints:
63+
- heroku-nel=https://nel.heroku.com/reports?ts=1698924101&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=lMEXEYwO4dAJOZgt0b6ihblK5I4BwDBadrW6odcdYW8%3D
64+
vary:
65+
- Cookie, origin
66+
via:
67+
- 1.1 vegur, 1.1 google
68+
x-content-type-options:
69+
- nosniff
70+
x-frame-options:
71+
- DENY
72+
http_version: HTTP/1.1
73+
status_code: 200
74+
version: 1

tests/cassettes/models-create.yaml

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
interactions:
2+
- request:
3+
body: '{"owner": "test", "name": "python-example", "visibility": "private", "hardware":
4+
"cpu", "description": "An example model"}'
5+
headers:
6+
accept:
7+
- '*/*'
8+
accept-encoding:
9+
- gzip, deflate
10+
connection:
11+
- keep-alive
12+
content-length:
13+
- '123'
14+
content-type:
15+
- application/json
16+
host:
17+
- api.replicate.com
18+
user-agent:
19+
- replicate-python/0.15.6
20+
method: POST
21+
uri: https://api.replicate.com/v1/models
22+
response:
23+
content: '{"url": "https://replicate.com/test/python-example", "owner": "test",
24+
"name": "python-example", "description": "An example model", "visibility": "private",
25+
"github_url": null, "paper_url": null, "license_url": null, "run_count": 0,
26+
"cover_image_url": null, "default_example": null, "latest_version": null}'
27+
headers:
28+
CF-Cache-Status:
29+
- DYNAMIC
30+
CF-RAY:
31+
- 81ff2e098ec0eb5b-SEA
32+
Connection:
33+
- keep-alive
34+
Content-Length:
35+
- '307'
36+
Content-Type:
37+
- application/json
38+
Date:
39+
- Thu, 02 Nov 2023 20:38:12 GMT
40+
Server:
41+
- cloudflare
42+
Strict-Transport-Security:
43+
- max-age=15552000
44+
allow:
45+
- GET, POST, HEAD, OPTIONS
46+
content-security-policy-report-only:
47+
- 'font-src ''report-sample'' ''self'' data:; img-src ''report-sample'' ''self''
48+
data: https://replicate.delivery https://*.replicate.delivery https://*.githubusercontent.com
49+
https://github.com; script-src ''report-sample'' ''self'' https://cdn.rudderlabs.com/v1.1/rudder-analytics.min.js;
50+
style-src ''report-sample'' ''self'' ''unsafe-inline''; connect-src ''report-sample''
51+
''self'' https://replicate.delivery https://*.replicate.delivery https://*.rudderlabs.com
52+
https://*.rudderstack.com https://*.mux.com https://*.sentry.io; worker-src
53+
''none''; media-src ''report-sample'' ''self'' https://replicate.delivery
54+
https://*.replicate.delivery https://*.mux.com https://*.sentry.io; default-src
55+
''self''; report-uri'
56+
cross-origin-opener-policy:
57+
- same-origin
58+
nel:
59+
- '{"report_to":"heroku-nel","max_age":3600,"success_fraction":0.005,"failure_fraction":0.05,"response_headers":["Via"]}'
60+
ratelimit-remaining:
61+
- '2999'
62+
ratelimit-reset:
63+
- '1'
64+
referrer-policy:
65+
- same-origin
66+
report-to:
67+
- '{"group":"heroku-nel","max_age":3600,"endpoints":[{"url":"https://nel.heroku.com/reports?ts=1698957492&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=m%2Fs583uNWdN4J4bm1G3JZoilUVMbh89egg%2FAEcTPZm4%3D"}]}'
68+
reporting-endpoints:
69+
- heroku-nel=https://nel.heroku.com/reports?ts=1698957492&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=m%2Fs583uNWdN4J4bm1G3JZoilUVMbh89egg%2FAEcTPZm4%3D
70+
vary:
71+
- Cookie, origin
72+
via:
73+
- 1.1 vegur, 1.1 google
74+
x-content-type-options:
75+
- nosniff
76+
x-frame-options:
77+
- DENY
78+
http_version: HTTP/1.1
79+
status_code: 201
80+
version: 1

tests/test_hardware.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import pytest
2+
3+
import replicate
4+
5+
6+
@pytest.mark.vcr("hardware-list.yaml")
@pytest.mark.asyncio
async def test_hardware_list(mock_replicate_api_token):
    """Check that listing hardware returns a non-empty list."""
    available = replicate.hardware.list()

    assert available is not None
    assert isinstance(available, list)
    assert available

0 commit comments

Comments
 (0)