Skip to content

Commit 61922f8

Browse files
authored
Refactor Collection superclass (#186)
This PR removes abstract methods from base collection, leaving each collection to define its own `get` / `create` / `list` methods on its own terms. The primary benefit to having abstract methods defined was to reuse shared logic for `reload`, but that method is short and required additional maneuvering to define an `id` field. The downside of this tight coupling was that resources were required to implement methods that would raise `NotImplemented`. And their signatures for `create` methods were inflexible, requiring either a mismatch with the parent class or an elaborate workaround with overloads and unpacked typed dict kwargs. --------- Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent 64c41c1 commit 61922f8

File tree

12 files changed

+127
-225
lines changed

12 files changed

+127
-225
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ disable = [
4646
"R0801", # Similar lines in N files
4747
"W0212", # Access to a protected member
4848
"W0622", # Redefining built-in
49+
"R0903", # Too few public methods
4950
]
5051

5152
[tool.ruff]

replicate/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .client import Client
1+
from replicate.client import Client
22

33
default_client = Client()
44
run = default_client.run

replicate/base_model.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,3 @@ class BaseModel(pydantic.BaseModel):
1919

2020
_client: "Client" = pydantic.PrivateAttr()
2121
_collection: "Collection" = pydantic.PrivateAttr()
22-
23-
def reload(self) -> None:
24-
"""
25-
Load this object from the server again.
26-
"""
27-
28-
new_model = self._collection.get(self.id) # pylint: disable=no-member
29-
for k, v in new_model.dict().items(): # pylint: disable=invalid-name
30-
setattr(self, k, v)

replicate/client.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@
1414

1515
import httpx
1616

17-
from .__about__ import __version__
18-
from .deployment import DeploymentCollection
19-
from .exceptions import ModelError, ReplicateError
20-
from .model import ModelCollection
21-
from .prediction import PredictionCollection
22-
from .training import TrainingCollection
23-
from .version import Version
17+
from replicate.__about__ import __version__
18+
from replicate.deployment import DeploymentCollection
19+
from replicate.exceptions import ModelError, ReplicateError
20+
from replicate.model import ModelCollection
21+
from replicate.prediction import PredictionCollection
22+
from replicate.schema import make_schema_backwards_compatible
23+
from replicate.training import TrainingCollection
24+
from replicate.version import Version
2425

2526

2627
class Client:
@@ -143,7 +144,9 @@ def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]: # noq
143144
version = Version(**resp.json())
144145

145146
# Return an iterator of the output
146-
schema = version.get_transformed_schema()
147+
schema = make_schema_backwards_compatible(
148+
version.openapi_schema, version.cog_version
149+
)
147150
output = schema["components"]["schemas"]["Output"]
148151
if (
149152
output.get("type") == "array"
@@ -175,9 +178,10 @@ class RetryTransport(httpx.AsyncBaseTransport, httpx.BaseTransport):
175178
)
176179
MAX_BACKOFF_WAIT = 60
177180

178-
def __init__(
181+
def __init__( # pylint: disable=too-many-arguments
179182
self,
180183
wrapped_transport: Union[httpx.BaseTransport, httpx.AsyncBaseTransport],
184+
*,
181185
max_attempts: int = 10,
182186
max_backoff_wait: float = MAX_BACKOFF_WAIT,
183187
backoff_factor: float = 0.1,

replicate/collection.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import abc
2-
from typing import TYPE_CHECKING, Dict, Generic, List, TypeVar, Union, cast
2+
from typing import TYPE_CHECKING, Dict, Generic, TypeVar, Union, cast
33

44
if TYPE_CHECKING:
55
from replicate.client import Client
@@ -15,29 +15,13 @@ class Collection(abc.ABC, Generic[Model]):
1515
A base class for representing objects of a particular type on the server.
1616
"""
1717

18+
_client: "Client"
19+
model: Model
20+
1821
def __init__(self, client: "Client") -> None:
1922
self._client = client
2023

21-
@property
22-
@abc.abstractmethod
23-
def model(self) -> Model: # pylint: disable=missing-function-docstring
24-
pass
25-
26-
@abc.abstractmethod
27-
def list(self) -> List[Model]: # pylint: disable=missing-function-docstring
28-
pass
29-
30-
@abc.abstractmethod
31-
def get(self, key: str) -> Model: # pylint: disable=missing-function-docstring
32-
pass
33-
34-
@abc.abstractmethod
35-
def create( # pylint: disable=missing-function-docstring
36-
self, *args, **kwargs
37-
) -> Model:
38-
pass
39-
40-
def prepare_model(self, attrs: Union[Model, Dict]) -> Model:
24+
def _prepare_model(self, attrs: Union[Model, Dict]) -> Model:
4125
"""
4226
Create a model from a set of attributes.
4327
"""

replicate/deployment.py

Lines changed: 24 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, overload
2-
3-
from typing_extensions import Unpack
1+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
42

53
from replicate.base_model import BaseModel
64
from replicate.collection import Collection
75
from replicate.files import upload_file
86
from replicate.json import encode_json
9-
from replicate.prediction import Prediction, PredictionCollection
7+
from replicate.prediction import Prediction
108

119
if TYPE_CHECKING:
1210
from replicate.client import Client
@@ -17,6 +15,8 @@ class Deployment(BaseModel):
1715
A deployment of a model hosted on Replicate.
1816
"""
1917

18+
_collection: "DeploymentCollection"
19+
2020
username: str
2121
"""
2222
The name of the user or organization that owns the deployment.
@@ -43,15 +43,6 @@ class DeploymentCollection(Collection):
4343

4444
model = Deployment
4545

46-
def list(self) -> List[Deployment]:
47-
"""
48-
List deployments.
49-
50-
Raises:
51-
NotImplementedError: This method is not implemented.
52-
"""
53-
raise NotImplementedError()
54-
5546
def get(self, name: str) -> Deployment:
5647
"""
5748
Get a deployment by name.
@@ -65,89 +56,35 @@ def get(self, name: str) -> Deployment:
6556
# TODO: fetch model from server
6657
# TODO: support permanent IDs
6758
username, name = name.split("/")
68-
return self.prepare_model({"username": username, "name": name})
69-
70-
def create(
71-
self,
72-
*args,
73-
**kwargs,
74-
) -> Deployment:
75-
"""
76-
Create a deployment.
77-
78-
Raises:
79-
NotImplementedError: This method is not implemented.
80-
"""
81-
raise NotImplementedError()
59+
return self._prepare_model({"username": username, "name": name})
8260

83-
def prepare_model(self, attrs: Union[Deployment, Dict]) -> Deployment:
61+
def _prepare_model(self, attrs: Union[Deployment, Dict]) -> Deployment:
8462
if isinstance(attrs, BaseModel):
8563
attrs.id = f"{attrs.username}/{attrs.name}"
8664
elif isinstance(attrs, dict):
8765
attrs["id"] = f"{attrs['username']}/{attrs['name']}"
88-
return super().prepare_model(attrs)
66+
return super()._prepare_model(attrs)
8967

9068

9169
class DeploymentPredictionCollection(Collection):
70+
"""
71+
Namespace for operations related to predictions in a deployment.
72+
"""
73+
9274
model = Prediction
9375

9476
def __init__(self, client: "Client", deployment: Deployment) -> None:
9577
super().__init__(client=client)
9678
self._deployment = deployment
9779

98-
def list(self) -> List[Prediction]:
99-
"""
100-
List predictions in a deployment.
101-
102-
Raises:
103-
NotImplementedError: This method is not implemented.
104-
"""
105-
raise NotImplementedError()
106-
107-
def get(self, id: str) -> Prediction:
108-
"""
109-
Get a prediction by ID.
110-
111-
Args:
112-
id: The ID of the prediction.
113-
Returns:
114-
Prediction: The prediction object.
115-
"""
116-
117-
resp = self._client._request("GET", f"/v1/predictions/{id}")
118-
obj = resp.json()
119-
# HACK: resolve this? make it lazy somehow?
120-
del obj["version"]
121-
return self.prepare_model(obj)
122-
123-
@overload
124-
def create( # pylint: disable=arguments-differ disable=too-many-arguments
80+
def create(
12581
self,
12682
input: Dict[str, Any],
12783
*,
12884
webhook: Optional[str] = None,
12985
webhook_completed: Optional[str] = None,
13086
webhook_events_filter: Optional[List[str]] = None,
13187
stream: Optional[bool] = None,
132-
) -> Prediction:
133-
...
134-
135-
@overload
136-
def create( # pylint: disable=arguments-differ disable=too-many-arguments
137-
self,
138-
*,
139-
input: Dict[str, Any],
140-
webhook: Optional[str] = None,
141-
webhook_completed: Optional[str] = None,
142-
webhook_events_filter: Optional[List[str]] = None,
143-
stream: Optional[bool] = None,
144-
) -> Prediction:
145-
...
146-
147-
def create(
148-
self,
149-
*args,
150-
**kwargs: Unpack[PredictionCollection.CreateParams], # type: ignore[misc]
15188
) -> Prediction:
15289
"""
15390
Create a new prediction with the deployment.
@@ -163,20 +100,21 @@ def create(
163100
Prediction: The created prediction object.
164101
"""
165102

166-
input = args[0] if len(args) > 0 else kwargs.get("input")
167-
if input is None:
168-
raise ValueError(
169-
"An input must be provided as a positional or keyword argument."
170-
)
171-
172103
body = {
173104
"input": encode_json(input, upload_file=upload_file),
174105
}
175106

176-
for key in ["webhook", "webhook_completed", "webhook_events_filter", "stream"]:
177-
value = kwargs.get(key)
178-
if value is not None:
179-
body[key] = value
107+
if webhook is not None:
108+
body["webhook"] = webhook
109+
110+
if webhook_completed is not None:
111+
body["webhook_completed"] = webhook_completed
112+
113+
if webhook_events_filter is not None:
114+
body["webhook_events_filter"] = webhook_events_filter
115+
116+
if stream is not None:
117+
body["stream"] = stream
180118

181119
resp = self._client._request(
182120
"POST",
@@ -186,4 +124,4 @@ def create(
186124
obj = resp.json()
187125
obj["deployment"] = self._deployment
188126
del obj["version"]
189-
return self.prepare_model(obj)
127+
return self._prepare_model(obj)

replicate/exceptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
class ReplicateException(Exception):
2-
pass
2+
"""A base class for all Replicate exceptions."""
33

44

55
class ModelError(ReplicateException):

replicate/model.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ class Model(BaseModel):
1414
A machine learning model hosted on Replicate.
1515
"""
1616

17+
_collection: "ModelCollection"
18+
1719
url: str
1820
"""
1921
The URL of the model.
@@ -105,6 +107,15 @@ def versions(self) -> VersionCollection:
105107

106108
return VersionCollection(client=self._client, model=self)
107109

110+
def reload(self) -> None:
111+
"""
112+
Load this object from the server.
113+
"""
114+
115+
obj = self._collection.get(f"{self.owner}/{self.name}") # pylint: disable=no-member
116+
for name, value in obj.dict().items():
117+
setattr(self, name, value)
118+
108119

109120
class ModelCollection(Collection):
110121
"""
@@ -124,7 +135,7 @@ def list(self) -> List[Model]:
124135
resp = self._client._request("GET", "/v1/models")
125136
# TODO: paginate
126137
models = resp.json()["results"]
127-
return [self.prepare_model(obj) for obj in models]
138+
return [self._prepare_model(obj) for obj in models]
128139

129140
def get(self, key: str) -> Model:
130141
"""
@@ -137,22 +148,9 @@ def get(self, key: str) -> Model:
137148
"""
138149

139150
resp = self._client._request("GET", f"/v1/models/{key}")
140-
return self.prepare_model(resp.json())
141-
142-
def create(
143-
self,
144-
*args,
145-
**kwargs,
146-
) -> Model:
147-
"""
148-
Create a model.
149-
150-
Raises:
151-
NotImplementedError: This method is not implemented.
152-
"""
153-
raise NotImplementedError()
151+
return self._prepare_model(resp.json())
154152

155-
def prepare_model(self, attrs: Union[Model, Dict]) -> Model:
153+
def _prepare_model(self, attrs: Union[Model, Dict]) -> Model:
156154
if isinstance(attrs, BaseModel):
157155
attrs.id = f"{attrs.owner}/{attrs.name}"
158156
elif isinstance(attrs, dict):
@@ -165,7 +163,7 @@ def prepare_model(self, attrs: Union[Model, Dict]) -> Model:
165163
if "latest_version" in attrs and attrs["latest_version"] == {}:
166164
attrs.pop("latest_version")
167165

168-
model = super().prepare_model(attrs)
166+
model = super()._prepare_model(attrs)
169167

170168
if model.default_example is not None:
171169
model.default_example._client = self._client

0 commit comments

Comments
 (0)