Skip to content

Add support for collections.list endpoint #190

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,38 @@ urlretrieve(out[0], "/tmp/out.png")
background = Image.open("/tmp/out.png")
```

## List models

You can the models you've created:

```python
replicate.models.list()
```

Lists of models are paginated. You can get the next page of models by passing the `next` property as an argument to the `list` method. Here's how you can get all the models you've created:

```python
models = []
page = replicate.models.list()

while page:
models.extend(page.results)
page = replicate.models.list(page.next) if page.next else None
```

You can also find collections of featured models on Replicate:

```python
>>> collections = replicate.collections.list()
>>> collections[0].slug
"vision-models"
>>> collections[0].description
"Multimodal large language models with vision capabilities like object detection and optical character recognition (OCR)"

>>> replicate.collections.get("text-to-image").models
[<Model: stability-ai/sdxl>, ...]
```

## Create a model

You can create a model for a user or organization
Expand Down
1 change: 1 addition & 0 deletions replicate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

default_client = Client()
run = default_client.run
collections = default_client.collections
hardware = default_client.hardware
deployments = default_client.deployments
models = default_client.models
Expand Down
8 changes: 8 additions & 0 deletions replicate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import httpx

from replicate.__about__ import __version__
from replicate.collection import Collections
from replicate.deployment import Deployments
from replicate.exceptions import ModelError, ReplicateError
from replicate.hardware import Hardwares
Expand Down Expand Up @@ -84,6 +85,13 @@ def _request(self, method: str, path: str, **kwargs) -> httpx.Response:

return resp

@property
def collections(self) -> Collections:
"""
Namespace for operations related to collections of models.
"""
return Collections(client=self)

@property
def deployments(self) -> Deployments:
"""
Expand Down
105 changes: 105 additions & 0 deletions replicate/collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from replicate.model import Model, Models
from replicate.pagination import Page
from replicate.resource import Namespace, Resource

if TYPE_CHECKING:
from replicate.client import Client


class Collection(Resource):
"""
A collection of models on Replicate.
"""

slug: str
"""The slug used to identify the collection."""

name: str
"""The name of the collection."""

description: str
"""A description of the collection."""

models: Optional[List[Model]] = None
"""The models in the collection."""

def __iter__(self): # noqa: ANN204
return iter(self.models)

def __getitem__(self, index) -> Optional[Model]:
if self.models is not None:
return self.models[index]

return None

def __len__(self) -> int:
if self.models is not None:
return len(self.models)

return 0


class Collections(Namespace):
"""
A namespace for operations related to collections of models.
"""

model = Collection

_models: Models

def __init__(self, client: "Client") -> None:
self._models = Models(client)
super().__init__(client)

def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Collection]: # noqa: F821
"""
List collections of models.

Parameters:
cursor: The cursor to use for pagination. Use the value of `Page.next` or `Page.previous`.
Returns:
Page[Collection]: A page of of model collections.
Raises:
ValueError: If `cursor` is `None`.
"""

if cursor is None:
raise ValueError("cursor cannot be None")

resp = self._client._request(
"GET", "/v1/collections" if cursor is ... else cursor
)

return Page[Collection](self._client, self, **resp.json())

def get(self, slug: str) -> Collection:
"""Get a model by name.

Args:
name: The name of the model, in the format `owner/model-name`.
Returns:
The model.
"""

resp = self._client._request("GET", f"/v1/collections/{slug}")

return self._prepare_model(resp.json())

def _prepare_model(self, attrs: Union[Collection, Dict]) -> Collection:
if isinstance(attrs, Resource):
attrs.id = attrs.slug

if attrs.models is not None:
attrs.models = [self._models._prepare_model(m) for m in attrs.models]
elif isinstance(attrs, dict):
attrs["id"] = attrs["slug"]

if "models" in attrs:
attrs["models"] = [
self._models._prepare_model(m) for m in attrs["models"]
]

return super()._prepare_model(attrs)
7 changes: 2 additions & 5 deletions replicate/hardware.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,12 @@ def list(self) -> List[Hardware]:
"""

resp = self._client._request("GET", "/v1/hardware")
hardware = resp.json()
return [self._prepare_model(obj) for obj in hardware]
return [self._prepare_model(obj) for obj in resp.json()]

def _prepare_model(self, attrs: Union[Hardware, Dict]) -> Hardware:
if isinstance(attrs, Resource):
attrs.id = attrs.sku
elif isinstance(attrs, dict):
attrs["id"] = attrs["sku"]

hardware = super()._prepare_model(attrs)

return hardware
return super()._prepare_model(attrs)
18,285 changes: 18,285 additions & 0 deletions tests/cassettes/collections-get.yaml

Large diffs are not rendered by default.

Loading