Skip to content

Commit b16235f

Browse files
committed
Add Deployment model and deployments collection property on Client
Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent 2b1da58 commit b16235f

File tree

4 files changed

+193
-0
lines changed

4 files changed

+193
-0
lines changed

replicate/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@
55
models = default_client.models
66
predictions = default_client.predictions
77
trainings = default_client.trainings
8+
deployments = default_client.deployments

replicate/client.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from requests.cookies import RequestsCookieJar
99

1010
from replicate.__about__ import __version__
11+
from replicate.deployment import DeploymentCollection
1112
from replicate.exceptions import ModelError, ReplicateError
1213
from replicate.model import ModelCollection
1314
from replicate.prediction import PredictionCollection
@@ -113,6 +114,10 @@ def predictions(self) -> PredictionCollection:
113114
def trainings(self) -> TrainingCollection:
114115
return TrainingCollection(client=self)
115116

117+
@property
118+
def deployments(self) -> DeploymentCollection:
119+
return DeploymentCollection(client=self)
120+
116121
def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]:
117122
"""
118123
Run a model and wait for its output.

replicate/deployment.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
2+
3+
from replicate.base_model import BaseModel
4+
from replicate.collection import Collection
5+
from replicate.files import upload_file
6+
from replicate.json import encode_json
7+
from replicate.prediction import Prediction
8+
9+
if TYPE_CHECKING:
10+
from replicate.client import Client
11+
12+
13+
class Deployment(BaseModel):
14+
"""
15+
A deployment of a model hosted on Replicate.
16+
"""
17+
18+
username: str
19+
"""
20+
The name of the user or organization that owns the deployment.
21+
"""
22+
23+
name: str
24+
"""
25+
The name of the deployment.
26+
"""
27+
28+
@property
29+
def predictions(self) -> "DeploymentPredictionCollection":
30+
"""
31+
Get the predictions for this deployment.
32+
"""
33+
34+
return DeploymentPredictionCollection(client=self._client, deployment=self)
35+
36+
37+
class DeploymentCollection(Collection):
38+
model = Deployment
39+
40+
def list(self) -> List[Deployment]:
41+
raise NotImplementedError()
42+
43+
def get(self, name: str) -> Deployment:
44+
"""
45+
Get a deployment by name.
46+
47+
Args:
48+
name: The name of the deployment, in the format `owner/model-name`.
49+
Returns:
50+
The model.
51+
"""
52+
53+
# TODO: fetch model from server
54+
# TODO: support permanent IDs
55+
username, name = name.split("/")
56+
return self.prepare_model({"username": username, "name": name})
57+
58+
def create(self, **kwargs) -> Deployment:
59+
raise NotImplementedError()
60+
61+
def prepare_model(self, attrs: Union[Deployment, Dict]) -> Deployment:
62+
if isinstance(attrs, BaseModel):
63+
attrs.id = f"{attrs.username}/{attrs.name}"
64+
elif isinstance(attrs, dict):
65+
attrs["id"] = f"{attrs['username']}/{attrs['name']}"
66+
return super().prepare_model(attrs)
67+
68+
69+
class DeploymentPredictionCollection(Collection):
70+
model = Prediction
71+
72+
def __init__(self, client: "Client", deployment: Deployment) -> None:
73+
super().__init__(client=client)
74+
self._deployment = deployment
75+
76+
def list(self) -> List[Prediction]:
77+
raise NotImplementedError()
78+
79+
def get(self, id: str) -> Prediction:
80+
"""
81+
Get a prediction by ID.
82+
83+
Args:
84+
id: The ID of the prediction.
85+
Returns:
86+
Prediction: The prediction object.
87+
"""
88+
89+
resp = self._client._request("GET", f"/v1/predictions/{id}")
90+
obj = resp.json()
91+
# HACK: resolve this? make it lazy somehow?
92+
del obj["version"]
93+
return self.prepare_model(obj)
94+
95+
def create( # type: ignore
96+
self,
97+
input: Dict[str, Any],
98+
webhook: Optional[str] = None,
99+
webhook_completed: Optional[str] = None,
100+
webhook_events_filter: Optional[List[str]] = None,
101+
*,
102+
stream: Optional[bool] = None,
103+
**kwargs,
104+
) -> Prediction:
105+
"""
106+
Create a new prediction with the deployment.
107+
108+
Args:
109+
input: The input data for the prediction.
110+
webhook: The URL to receive a POST request with prediction updates.
111+
webhook_completed: The URL to receive a POST request when the prediction is completed.
112+
webhook_events_filter: List of events to trigger webhooks.
113+
stream: Set to True to enable streaming of prediction output.
114+
115+
Returns:
116+
Prediction: The created prediction object.
117+
"""
118+
119+
input = encode_json(input, upload_file=upload_file)
120+
body = {
121+
"input": input,
122+
}
123+
if webhook is not None:
124+
body["webhook"] = webhook
125+
if webhook_completed is not None:
126+
body["webhook_completed"] = webhook_completed
127+
if webhook_events_filter is not None:
128+
body["webhook_events_filter"] = webhook_events_filter
129+
if stream is True:
130+
body["stream"] = "true"
131+
132+
resp = self._client._request(
133+
"POST",
134+
f"/v1/deployments/{self._deployment.username}/{self._deployment.name}/predictions",
135+
json=body,
136+
)
137+
obj = resp.json()
138+
obj["deployment"] = self._deployment
139+
del obj["version"]
140+
return self.prepare_model(obj)

tests/test_deployment.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import responses
2+
from responses import matchers
3+
4+
from replicate.client import Client
5+
6+
7+
@responses.activate
8+
def test_deployment_predictions_create():
9+
client = Client(api_token="abc123")
10+
11+
deployment = client.deployments.get("test/model")
12+
13+
rsp = responses.post(
14+
"https://api.replicate.com/v1/deployments/test/model/predictions",
15+
match=[
16+
matchers.json_params_matcher(
17+
{
18+
"input": {"text": "world"},
19+
"webhook": "https://example.com/webhook",
20+
"webhook_events_filter": ["completed"],
21+
}
22+
),
23+
],
24+
json={
25+
"id": "p1",
26+
"version": "v1",
27+
"urls": {
28+
"get": "https://api.replicate.com/v1/predictions/p1",
29+
"cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
30+
},
31+
"created_at": "2022-04-26T20:00:40.658234Z",
32+
"source": "api",
33+
"status": "processing",
34+
"input": {"text": "hello"},
35+
"output": None,
36+
"error": None,
37+
"logs": "",
38+
},
39+
)
40+
41+
deployment.predictions.create(
42+
input={"text": "world"},
43+
webhook="https://example.com/webhook",
44+
webhook_events_filter=["completed"],
45+
)
46+
47+
assert rsp.call_count == 1

0 commit comments

Comments
 (0)