Skip to content

Commit 9bfa08a

Browse files
authored
Add support for models.predictions.create endpoint (#207)
Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent 5f7ae72 commit 9bfa08a

File tree

7 files changed

+193
-5
lines changed

7 files changed

+193
-5
lines changed

.github/workflows/ci.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ jobs:
1313

1414
name: "Test Python ${{ matrix.python-version }}"
1515

16+
timeout-minutes: 10
17+
1618
strategy:
1719
fail-fast: false
1820
matrix:

replicate/identifier.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,29 @@
22
from typing import NamedTuple
33

44

5+
class ModelIdentifier(NamedTuple):
6+
"""
7+
A reference to a model in the format owner/name:version.
8+
"""
9+
10+
owner: str
11+
name: str
12+
13+
@classmethod
14+
def parse(cls, ref: str) -> "ModelIdentifier":
15+
"""
16+
Split a reference in the format owner/name:version into its components.
17+
"""
18+
19+
match = re.match(r"^(?P<owner>[^/]+)/(?P<name>[^:]+)$", ref)
20+
if not match:
21+
raise ValueError(
22+
f"Invalid reference to model version: {ref}. Expected format: owner/name"
23+
)
24+
25+
return cls(match.group("owner"), match.group("name"))
26+
27+
528
class ModelVersionIdentifier(NamedTuple):
629
"""
730
A reference to a model version in the format owner/name:version.

replicate/model.py

Lines changed: 83 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1-
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
1+
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union
22

33
from typing_extensions import NotRequired, TypedDict, Unpack, deprecated
44

55
from replicate.exceptions import ReplicateException
6+
from replicate.identifier import ModelIdentifier
67
from replicate.pagination import Page
7-
from replicate.prediction import Prediction
8+
from replicate.prediction import (
9+
Prediction,
10+
_create_prediction_body,
11+
_json_to_prediction,
12+
)
813
from replicate.resource import Namespace, Resource
914
from replicate.version import Version, Versions
1015

@@ -16,6 +21,7 @@
1621

1722
if TYPE_CHECKING:
1823
from replicate.client import Client
24+
from replicate.prediction import Predictions
1925

2026

2127
class Model(Resource):
@@ -140,6 +146,14 @@ class Models(Namespace):
140146

141147
model = Model
142148

149+
@property
150+
def predictions(self) -> "ModelsPredictions":
151+
"""
152+
Get a namespace for operations related to predictions on a model.
153+
"""
154+
155+
return ModelsPredictions(client=self._client)
156+
143157
def list(self, cursor: Union[str, "ellipsis", None] = ...) -> Page[Model]: # noqa: F821
144158
"""
145159
List all public models.
@@ -275,6 +289,54 @@ async def async_create(
275289
return _json_to_model(self._client, resp.json())
276290

277291

292+
class ModelsPredictions(Namespace):
293+
"""
294+
Namespace for operations related to predictions in a deployment.
295+
"""
296+
297+
def create(
298+
self,
299+
model: Optional[Union[str, Tuple[str, str], "Model"]],
300+
input: Dict[str, Any],
301+
**params: Unpack["Predictions.CreatePredictionParams"],
302+
) -> Prediction:
303+
"""
304+
Create a new prediction with the deployment.
305+
"""
306+
307+
url = _create_prediction_url_from_model(model)
308+
body = _create_prediction_body(version=None, input=input, **params)
309+
310+
resp = self._client._request(
311+
"POST",
312+
url,
313+
json=body,
314+
)
315+
316+
return _json_to_prediction(self._client, resp.json())
317+
318+
async def async_create(
319+
self,
320+
model: Optional[Union[str, Tuple[str, str], "Model"]],
321+
input: Dict[str, Any],
322+
**params: Unpack["Predictions.CreatePredictionParams"],
323+
) -> Prediction:
324+
"""
325+
Create a new prediction with the deployment.
326+
"""
327+
328+
url = _create_prediction_url_from_model(model)
329+
body = _create_prediction_body(version=None, input=input, **params)
330+
331+
resp = await self._client._async_request(
332+
"POST",
333+
url,
334+
json=body,
335+
)
336+
337+
return _json_to_prediction(self._client, resp.json())
338+
339+
278340
def _create_model_body( # pylint: disable=too-many-arguments
279341
owner: str,
280342
name: str,
@@ -318,3 +380,22 @@ def _json_to_model(client: "Client", json: Dict[str, Any]) -> Model:
318380
if model.default_example is not None:
319381
model.default_example._client = client
320382
return model
383+
384+
385+
def _create_prediction_url_from_model(
386+
model: Union[str, Tuple[str, str], "Model"]
387+
) -> str:
388+
owner, name = None, None
389+
if isinstance(model, Model):
390+
owner, name = model.owner, model.name
391+
elif isinstance(model, tuple):
392+
owner, name = model[0], model[1]
393+
elif isinstance(model, str):
394+
owner, name = ModelIdentifier.parse(model)
395+
396+
if owner is None or name is None:
397+
raise ValueError(
398+
"model must be a Model, a tuple of (owner, name), or a string in the format 'owner/name'"
399+
)
400+
401+
return f"/v1/models/{owner}/{name}/predictions"

replicate/training.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from typing_extensions import NotRequired, Unpack
1414

1515
from replicate.files import upload_file
16-
from replicate.identifier import ModelVersionIdentifier
16+
from replicate.identifier import ModelIdentifier, ModelVersionIdentifier
1717
from replicate.json import encode_json
1818
from replicate.model import Model
1919
from replicate.pagination import Page
@@ -378,8 +378,12 @@ def _create_training_url_from_model_and_version(
378378
owner, name = model.owner, model.name
379379
elif isinstance(model, tuple):
380380
owner, name = model[0], model[1]
381+
elif isinstance(model, str):
382+
owner, name = ModelIdentifier.parse(model)
381383
else:
382-
raise ValueError("model must be a Model or a tuple of (owner, name)")
384+
raise ValueError(
385+
"model must be a Model, a tuple of (owner, name), or a string in the format 'owner/name'"
386+
)
383387

384388
if isinstance(version, Version):
385389
version_id = version.id

replicate/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class Versions(Namespace):
3535
model: Tuple[str, str]
3636

3737
def __init__(
38-
self, client: "Client", model: Union["Model", str, Tuple[str, str]]
38+
self, client: "Client", model: Union[str, Tuple[str, str], "Model"]
3939
) -> None:
4040
super().__init__(client=client)
4141

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
interactions:
2+
- request:
3+
body: '{"input": {"prompt": "Please write a haiku about llamas."}}'
4+
headers:
5+
accept:
6+
- '*/*'
7+
accept-encoding:
8+
- gzip, deflate
9+
connection:
10+
- keep-alive
11+
content-length:
12+
- '59'
13+
content-type:
14+
- application/json
15+
host:
16+
- api.replicate.com
17+
user-agent:
18+
- replicate-python/0.21.0
19+
method: POST
20+
uri: https://api.replicate.com/v1/models/meta/llama-2-70b-chat/predictions
21+
response:
22+
content: '{"id":"heat2o3bzn3ahtr6bjfftvbaci","model":"replicate/lifeboat-70b","version":"d-c6559c5791b50af57b69f4a73f8e021c","input":{"prompt":"Please
23+
write a haiku about llamas."},"logs":"","error":null,"status":"starting","created_at":"2023-11-27T13:35:45.99397566Z","urls":{"cancel":"https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci/cancel","get":"https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci"}}
24+
25+
'
26+
headers:
27+
CF-Cache-Status:
28+
- DYNAMIC
29+
CF-RAY:
30+
- 82cac197efaec53d-SEA
31+
Connection:
32+
- keep-alive
33+
Content-Length:
34+
- '431'
35+
Content-Type:
36+
- application/json
37+
Date:
38+
- Mon, 27 Nov 2023 13:35:46 GMT
39+
NEL:
40+
- '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}'
41+
Report-To:
42+
- '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=7R5RONMF6xaGRc39n0wnSe3jU1FbpX64Xz4U%2B%2F2nasvFaz0pKARxPhnzDgYkLaWgdK9zWrD2jxU04aKOy5HMPHAXboJ993L4zfsOyto56lBtdqSjNgkptzzxYEsKD%2FxIhe2F"}],"group":"cf-nel","max_age":604800}'
43+
Server:
44+
- cloudflare
45+
Strict-Transport-Security:
46+
- max-age=15552000
47+
ratelimit-remaining:
48+
- '599'
49+
ratelimit-reset:
50+
- '1'
51+
via:
52+
- 1.1 google
53+
http_version: HTTP/1.1
54+
status_code: 201
55+
version: 1

tests/test_model.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,26 @@ async def test_models_create_with_positional_arguments(async_flag):
107107
assert model.owner == "test"
108108
assert model.name == "python-example"
109109
assert model.visibility == "private"
110+
111+
112+
@pytest.mark.vcr("models-predictions-create.yaml")
113+
@pytest.mark.asyncio
114+
@pytest.mark.parametrize("async_flag", [True, False])
115+
async def test_models_predictions_create(async_flag):
116+
input = {
117+
"prompt": "Please write a haiku about llamas.",
118+
}
119+
120+
if async_flag:
121+
prediction = await replicate.models.predictions.async_create(
122+
"meta/llama-2-70b-chat", input=input
123+
)
124+
else:
125+
prediction = replicate.models.predictions.create(
126+
"meta/llama-2-70b-chat", input=input
127+
)
128+
129+
assert prediction.id is not None
130+
# assert prediction.model == "meta/llama-2-70b-chat"
131+
assert prediction.model == "replicate/lifeboat-70b" # FIXME: this is temporary
132+
assert prediction.status == "starting"

0 commit comments

Comments
 (0)