Skip to content

Rename BaseModel and Collection to Resource and Namespace #188

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 0 additions & 21 deletions replicate/base_model.py

This file was deleted.

30 changes: 15 additions & 15 deletions replicate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
import httpx

from replicate.__about__ import __version__
from replicate.deployment import DeploymentCollection
from replicate.deployment import Deployments
from replicate.exceptions import ModelError, ReplicateError
from replicate.hardware import HardwareCollection
from replicate.model import ModelCollection
from replicate.prediction import PredictionCollection
from replicate.hardware import Hardwares
from replicate.model import Models
from replicate.prediction import Predictions
from replicate.schema import make_schema_backwards_compatible
from replicate.training import TrainingCollection
from replicate.training import Trainings
from replicate.version import Version


Expand Down Expand Up @@ -85,39 +85,39 @@ def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
return resp

@property
def deployments(self) -> DeploymentCollection:
def deployments(self) -> Deployments:
"""
Namespace for operations related to deployments.
"""
return DeploymentCollection(client=self)
return Deployments(client=self)

@property
def hardware(self) -> HardwareCollection:
def hardware(self) -> Hardwares:
"""
Namespace for operations related to hardware.
"""
return HardwareCollection(client=self)
return Hardwares(client=self)

@property
def models(self) -> ModelCollection:
def models(self) -> Models:
"""
Namespace for operations related to models.
"""
return ModelCollection(client=self)
return Models(client=self)

@property
def predictions(self) -> PredictionCollection:
def predictions(self) -> Predictions:
"""
Namespace for operations related to predictions.
"""
return PredictionCollection(client=self)
return Predictions(client=self)

@property
def trainings(self) -> TrainingCollection:
def trainings(self) -> Trainings:
"""
Namespace for operations related to trainings.
"""
return TrainingCollection(client=self)
return Trainings(client=self)

def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]: # noqa: ANN401
"""
Expand Down
17 changes: 8 additions & 9 deletions replicate/deployment.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from replicate.base_model import BaseModel
from replicate.collection import Collection
from replicate.files import upload_file
from replicate.json import encode_json
from replicate.prediction import Prediction
from replicate.resource import Namespace, Resource

if TYPE_CHECKING:
from replicate.client import Client


class Deployment(BaseModel):
class Deployment(Resource):
"""
A deployment of a model hosted on Replicate.
"""

_collection: "DeploymentCollection"
_namespace: "Deployments"

username: str
"""
Expand All @@ -28,15 +27,15 @@ class Deployment(BaseModel):
"""

@property
def predictions(self) -> "DeploymentPredictionCollection":
def predictions(self) -> "DeploymentPredictions":
"""
Get the predictions for this deployment.
"""

return DeploymentPredictionCollection(client=self._client, deployment=self)
return DeploymentPredictions(client=self._client, deployment=self)


class DeploymentCollection(Collection):
class Deployments(Namespace):
"""
Namespace for operations related to deployments.
"""
Expand All @@ -59,14 +58,14 @@ def get(self, name: str) -> Deployment:
return self._prepare_model({"username": username, "name": name})

def _prepare_model(self, attrs: Union[Deployment, Dict]) -> Deployment:
if isinstance(attrs, BaseModel):
if isinstance(attrs, Resource):
attrs.id = f"{attrs.username}/{attrs.name}"
elif isinstance(attrs, dict):
attrs["id"] = f"{attrs['username']}/{attrs['name']}"
return super()._prepare_model(attrs)


class DeploymentPredictionCollection(Collection):
class DeploymentPredictions(Namespace):
"""
Namespace for operations related to predictions in a deployment.
"""
Expand Down
9 changes: 4 additions & 5 deletions replicate/hardware.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import Dict, List, Union

from replicate.base_model import BaseModel
from replicate.collection import Collection
from replicate.resource import Namespace, Resource


class Hardware(BaseModel):
class Hardware(Resource):
"""
Hardware for running a model on Replicate.
"""
Expand All @@ -20,7 +19,7 @@ class Hardware(BaseModel):
"""


class HardwareCollection(Collection):
class Hardwares(Namespace):
"""
Namespace for operations related to hardware.
"""
Expand All @@ -40,7 +39,7 @@ def list(self) -> List[Hardware]:
return [self._prepare_model(obj) for obj in hardware]

def _prepare_model(self, attrs: Union[Hardware, Dict]) -> Hardware:
if isinstance(attrs, BaseModel):
if isinstance(attrs, Resource):
attrs.id = attrs.sku
elif isinstance(attrs, dict):
attrs["id"] = attrs["sku"]
Expand Down
19 changes: 9 additions & 10 deletions replicate/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,18 @@

from typing_extensions import deprecated

from replicate.base_model import BaseModel
from replicate.collection import Collection
from replicate.exceptions import ReplicateException
from replicate.prediction import Prediction
from replicate.version import Version, VersionCollection
from replicate.resource import Namespace, Resource
from replicate.version import Version, Versions


class Model(BaseModel):
class Model(Resource):
"""
A machine learning model hosted on Replicate.
"""

_collection: "ModelCollection"
_namespace: "Models"

url: str
"""
Expand Down Expand Up @@ -100,24 +99,24 @@ def predict(self, *args, **kwargs) -> None:
)

@property
def versions(self) -> VersionCollection:
def versions(self) -> Versions:
"""
Get the versions of this model.
"""

return VersionCollection(client=self._client, model=self)
return Versions(client=self._client, model=self)

def reload(self) -> None:
"""
Load this object from the server.
"""

obj = self._collection.get(f"{self.owner}/{self.name}") # pylint: disable=no-member
obj = self._namespace.get(f"{self.owner}/{self.name}") # pylint: disable=no-member
for name, value in obj.dict().items():
setattr(self, name, value)


class ModelCollection(Collection):
class Models(Namespace):
"""
Namespace for operations related to models.
"""
Expand Down Expand Up @@ -208,7 +207,7 @@ def create( # pylint: disable=arguments-differ disable=too-many-arguments
return self._prepare_model(resp.json())

def _prepare_model(self, attrs: Union[Model, Dict]) -> Model:
if isinstance(attrs, BaseModel):
if isinstance(attrs, Resource):
attrs.id = f"{attrs.owner}/{attrs.name}"
elif isinstance(attrs, dict):
attrs["id"] = f"{attrs['owner']}/{attrs['name']}"
Expand Down
11 changes: 5 additions & 6 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,19 @@
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, Union

from replicate.base_model import BaseModel
from replicate.collection import Collection
from replicate.exceptions import ModelError
from replicate.files import upload_file
from replicate.json import encode_json
from replicate.resource import Namespace, Resource
from replicate.version import Version


class Prediction(BaseModel):
class Prediction(Resource):
"""
A prediction made by a model hosted on Replicate.
"""

_collection: "PredictionCollection"
_namespace: "Predictions"

id: str
"""The unique ID of the prediction."""
Expand Down Expand Up @@ -146,12 +145,12 @@ def reload(self) -> None:
Load this prediction from the server.
"""

obj = self._collection.get(self.id) # pylint: disable=no-member
obj = self._namespace.get(self.id) # pylint: disable=no-member
for name, value in obj.dict().items():
setattr(self, name, value)


class PredictionCollection(Collection):
class Predictions(Namespace):
"""
Namespace for operations related to predictions.
"""
Expand Down
30 changes: 23 additions & 7 deletions replicate/collection.py → replicate/resource.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,32 @@
import abc
from typing import TYPE_CHECKING, Dict, Generic, TypeVar, Union, cast

from replicate.exceptions import ReplicateException

try:
from pydantic import v1 as pydantic # type: ignore
except ImportError:
import pydantic # type: ignore

if TYPE_CHECKING:
from replicate.client import Client

from replicate.base_model import BaseModel
from replicate.exceptions import ReplicateException

Model = TypeVar("Model", bound=BaseModel)
class Resource(pydantic.BaseModel):
"""
A base class for representing a single object on the server.
"""

id: str

_client: "Client" = pydantic.PrivateAttr()
_namespace: "Namespace" = pydantic.PrivateAttr()


Model = TypeVar("Model", bound=Resource)


class Collection(abc.ABC, Generic[Model]):
class Namespace(abc.ABC, Generic[Model]):
"""
A base class for representing objects of a particular type on the server.
"""
Expand All @@ -25,15 +41,15 @@ def _prepare_model(self, attrs: Union[Model, Dict]) -> Model:
"""
Create a model from a set of attributes.
"""
if isinstance(attrs, BaseModel):
if isinstance(attrs, Resource):
attrs._client = self._client
attrs._collection = self
attrs._namespace = self
return cast(Model, attrs)

if isinstance(attrs, dict) and self.model is not None and callable(self.model):
model = self.model(**attrs)
model._client = self._client
model._collection = self
model._namespace = self
return model

name = self.model.__name__ if hasattr(self.model, "__name__") else "model"
Expand Down
11 changes: 5 additions & 6 deletions replicate/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,19 @@

from typing_extensions import NotRequired, Unpack, overload

from replicate.base_model import BaseModel
from replicate.collection import Collection
from replicate.exceptions import ReplicateException
from replicate.files import upload_file
from replicate.json import encode_json
from replicate.resource import Namespace, Resource
from replicate.version import Version


class Training(BaseModel):
class Training(Resource):
"""
A training made for a model hosted on Replicate.
"""

_collection: "TrainingCollection"
_namespace: "Trainings"

id: str
"""The unique ID of the training."""
Expand Down Expand Up @@ -69,12 +68,12 @@ def reload(self) -> None:
Load the training from the server.
"""

obj = self._collection.get(self.id) # pylint: disable=no-member
obj = self._namespace.get(self.id) # pylint: disable=no-member
for name, value in obj.dict().items():
setattr(self, name, value)


class TrainingCollection(Collection):
class Trainings(Namespace):
"""
Namespace for operations related to trainings.
"""
Expand Down
Loading