Skip to content

Commit 39f28eb

Browse files
committed
Add support for models.search endpoint
Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent a47f276 commit 39f28eb

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

replicate/model.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,46 @@ async def async_list(
207207

208208
return Page[Model](**obj)
209209

210+
def search(self, query: str) -> Page[Model]:
211+
"""
212+
Search for public models.
213+
214+
Parameters:
215+
query: The search query.
216+
Returns:
217+
Page[Model]: A page of models matching the search query.
218+
"""
219+
resp = self._client._request(
220+
"QUERY", "/v1/models", content=query, headers={"Content-Type": "text/plain"}
221+
)
222+
223+
obj = resp.json()
224+
obj["results"] = [
225+
_json_to_model(self._client, result) for result in obj["results"]
226+
]
227+
228+
return Page[Model](**obj)
229+
230+
async def async_search(self, query: str) -> Page[Model]:
231+
"""
232+
Asynchronously search for public models.
233+
234+
Parameters:
235+
query: The search query.
236+
Returns:
237+
Page[Model]: A page of models matching the search query.
238+
"""
239+
resp = await self._client._async_request(
240+
"QUERY", "/v1/models", content=query, headers={"Content-Type": "text/plain"}
241+
)
242+
243+
obj = resp.json()
244+
obj["results"] = [
245+
_json_to_model(self._client, result) for result in obj["results"]
246+
]
247+
248+
return Page[Model](**obj)
249+
210250
@overload
211251
def get(self, key: str) -> Model: ...
212252

tests/test_model.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22

33
import replicate
4+
from replicate.model import Model, Page
45

56

67
@pytest.mark.vcr("models-get.yaml")
@@ -130,3 +131,26 @@ async def test_models_predictions_create(async_flag):
130131
# assert prediction.model == "meta/llama-2-70b-chat"
131132
assert prediction.model == "replicate/lifeboat-70b" # FIXME: this is temporary
132133
assert prediction.status == "starting"
134+
135+
136+
@pytest.mark.vcr("models-search.yaml")
137+
@pytest.mark.asyncio
138+
@pytest.mark.parametrize("async_flag", [True, False])
139+
async def test_models_search(async_flag):
140+
query = "llama"
141+
142+
if async_flag:
143+
page = await replicate.models.async_search(query)
144+
else:
145+
page = replicate.models.search(query)
146+
147+
assert isinstance(page, Page)
148+
assert len(page.results) > 0
149+
150+
for model in page.results:
151+
assert isinstance(model, Model)
152+
assert model.id is not None
153+
assert model.owner is not None
154+
assert model.name is not None
155+
156+
assert any("meta" in model.name.lower() for model in page.results)

0 commit comments

Comments
 (0)