Skip to content

Commit c41c8c5

Browse files
committed
Add support for hardware.list endpoint
Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent 61922f8 commit c41c8c5

File tree

5 files changed

+156
-8
lines changed

5 files changed

+156
-8
lines changed

replicate/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
default_client = Client()
44
run = default_client.run
5+
hardware = default_client.hardware
6+
deployments = default_client.deployments
57
models = default_client.models
68
predictions = default_client.predictions
79
trainings = default_client.trainings
8-
deployments = default_client.deployments

replicate/client.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from replicate.__about__ import __version__
1818
from replicate.deployment import DeploymentCollection
1919
from replicate.exceptions import ModelError, ReplicateError
20+
from replicate.hardware import HardwareCollection
2021
from replicate.model import ModelCollection
2122
from replicate.prediction import PredictionCollection
2223
from replicate.schema import make_schema_backwards_compatible
@@ -83,6 +84,20 @@ def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
8384

8485
return resp
8586

87+
@property
88+
def deployments(self) -> DeploymentCollection:
89+
"""
90+
Namespace for operations related to deployments.
91+
"""
92+
return DeploymentCollection(client=self)
93+
94+
@property
95+
def hardware(self) -> HardwareCollection:
96+
"""
97+
Namespace for operations related to hardware.
98+
"""
99+
return HardwareCollection(client=self)
100+
86101
@property
87102
def models(self) -> ModelCollection:
88103
"""
@@ -104,13 +119,6 @@ def trainings(self) -> TrainingCollection:
104119
"""
105120
return TrainingCollection(client=self)
106121

107-
@property
108-
def deployments(self) -> DeploymentCollection:
109-
"""
110-
Namespace for operations related to deployments.
111-
"""
112-
return DeploymentCollection(client=self)
113-
114122
def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]: # noqa: ANN401
115123
"""
116124
Run a model and wait for its output.

replicate/hardware.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from typing import Dict, List, Union
2+
3+
from replicate.base_model import BaseModel
4+
from replicate.collection import Collection
5+
6+
7+
class Hardware(BaseModel):
8+
"""
9+
Hardware for running a model on Replicate.
10+
"""
11+
12+
sku: str
13+
"""
14+
The SKU of the hardware.
15+
"""
16+
17+
name: str
18+
"""
19+
The name of the hardware.
20+
"""
21+
22+
23+
class HardwareCollection(Collection):
24+
"""
25+
Namespace for operations related to hardware.
26+
"""
27+
28+
model = Hardware
29+
30+
def list(self) -> List[Hardware]:
31+
"""
32+
List all public models.
33+
34+
Returns:
35+
A list of models.
36+
"""
37+
38+
resp = self._client._request("GET", "/v1/hardware")
39+
hardware = resp.json()
40+
return [self._prepare_model(obj) for obj in hardware]
41+
42+
def _prepare_model(self, attrs: Union[Hardware, Dict]) -> Hardware:
43+
if isinstance(attrs, BaseModel):
44+
attrs.id = attrs.sku
45+
elif isinstance(attrs, dict):
46+
attrs["id"] = attrs["sku"]
47+
48+
hardware = super()._prepare_model(attrs)
49+
50+
return hardware

tests/cassettes/hardware-list.yaml

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
interactions:
2+
- request:
3+
body: ''
4+
headers:
5+
accept:
6+
- '*/*'
7+
accept-encoding:
8+
- gzip, deflate
9+
connection:
10+
- keep-alive
11+
host:
12+
- api.replicate.com
13+
user-agent:
14+
- replicate-python/0.15.5
15+
method: GET
16+
uri: https://api.replicate.com/v1/hardware
17+
response:
18+
content: '[{"sku":"cpu","name":"CPU"},{"sku":"gpu-t4","name":"Nvidia T4 GPU"},{"sku":"gpu-a40-small","name":"Nvidia
19+
A40 GPU"},{"sku":"gpu-a40-large","name":"Nvidia A40 (Large) GPU"}]'
20+
headers:
21+
CF-Cache-Status:
22+
- DYNAMIC
23+
CF-RAY:
24+
- 81fbfed29fe1c58a-SEA
25+
Connection:
26+
- keep-alive
27+
Content-Encoding:
28+
- gzip
29+
Content-Type:
30+
- application/json
31+
Date:
32+
- Thu, 02 Nov 2023 11:21:41 GMT
33+
Server:
34+
- cloudflare
35+
Strict-Transport-Security:
36+
- max-age=15552000
37+
Transfer-Encoding:
38+
- chunked
39+
allow:
40+
- OPTIONS, GET
41+
content-security-policy-report-only:
42+
- 'connect-src ''report-sample'' ''self'' https://replicate.delivery https://*.replicate.delivery
43+
https://*.rudderlabs.com https://*.rudderstack.com https://*.mux.com https://*.sentry.io;
44+
worker-src ''none''; script-src ''report-sample'' ''self'' https://cdn.rudderlabs.com/v1.1/rudder-analytics.min.js;
45+
style-src ''report-sample'' ''self'' ''unsafe-inline''; font-src ''report-sample''
46+
''self'' data:; img-src ''report-sample'' ''self'' data: https://replicate.delivery
47+
https://*.replicate.delivery https://*.githubusercontent.com https://github.com;
48+
default-src ''self''; media-src ''report-sample'' ''self'' https://replicate.delivery
49+
https://*.replicate.delivery https://*.mux.com https://*.sentry.io; report-uri'
50+
cross-origin-opener-policy:
51+
- same-origin
52+
nel:
53+
- '{"report_to":"heroku-nel","max_age":3600,"success_fraction":0.005,"failure_fraction":0.05,"response_headers":["Via"]}'
54+
ratelimit-remaining:
55+
- '2999'
56+
ratelimit-reset:
57+
- '1'
58+
referrer-policy:
59+
- same-origin
60+
report-to:
61+
- '{"group":"heroku-nel","max_age":3600,"endpoints":[{"url":"https://nel.heroku.com/reports?ts=1698924101&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=lMEXEYwO4dAJOZgt0b6ihblK5I4BwDBadrW6odcdYW8%3D"}]}'
62+
reporting-endpoints:
63+
- heroku-nel=https://nel.heroku.com/reports?ts=1698924101&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=lMEXEYwO4dAJOZgt0b6ihblK5I4BwDBadrW6odcdYW8%3D
64+
vary:
65+
- Cookie, origin
66+
via:
67+
- 1.1 vegur, 1.1 google
68+
x-content-type-options:
69+
- nosniff
70+
x-frame-options:
71+
- DENY
72+
http_version: HTTP/1.1
73+
status_code: 200
74+
version: 1

tests/test_hardware.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import httpx
2+
import pytest
3+
import respx
4+
5+
import replicate
6+
7+
8+
@pytest.mark.vcr("hardware-list.yaml")
9+
@pytest.mark.asyncio
10+
async def test_hardware_list(mock_replicate_api_token):
11+
hardware = replicate.hardware.list()
12+
13+
assert hardware is not None
14+
assert isinstance(hardware, list)
15+
assert len(hardware) > 0

0 commit comments

Comments
 (0)