File tree Expand file tree Collapse file tree 2 files changed +64
-0
lines changed Expand file tree Collapse file tree 2 files changed +64
-0
lines changed Original file line number Diff line number Diff line change @@ -207,6 +207,46 @@ async def async_list(
207
207
208
208
return Page [Model ](** obj )
209
209
210
+ def search (self , query : str ) -> Page [Model ]:
211
+ """
212
+ Search for public models.
213
+
214
+ Parameters:
215
+ query: The search query.
216
+ Returns:
217
+ Page[Model]: A page of models matching the search query.
218
+ """
219
+ resp = self ._client ._request (
220
+ "QUERY" , "/v1/models" , content = query , headers = {"Content-Type" : "text/plain" }
221
+ )
222
+
223
+ obj = resp .json ()
224
+ obj ["results" ] = [
225
+ _json_to_model (self ._client , result ) for result in obj ["results" ]
226
+ ]
227
+
228
+ return Page [Model ](** obj )
229
+
230
+ async def async_search (self , query : str ) -> Page [Model ]:
231
+ """
232
+ Asynchronously search for public models.
233
+
234
+ Parameters:
235
+ query: The search query.
236
+ Returns:
237
+ Page[Model]: A page of models matching the search query.
238
+ """
239
+ resp = await self ._client ._async_request (
240
+ "QUERY" , "/v1/models" , content = query , headers = {"Content-Type" : "text/plain" }
241
+ )
242
+
243
+ obj = resp .json ()
244
+ obj ["results" ] = [
245
+ _json_to_model (self ._client , result ) for result in obj ["results" ]
246
+ ]
247
+
248
+ return Page [Model ](** obj )
249
+
210
250
@overload
211
251
def get (self , key : str ) -> Model : ...
212
252
Original file line number Diff line number Diff line change 1
1
import pytest
2
2
3
3
import replicate
4
+ from replicate .model import Model , Page
4
5
5
6
6
7
@pytest .mark .vcr ("models-get.yaml" )
@@ -130,3 +131,26 @@ async def test_models_predictions_create(async_flag):
130
131
# assert prediction.model == "meta/llama-2-70b-chat"
131
132
assert prediction .model == "replicate/lifeboat-70b" # FIXME: this is temporary
132
133
assert prediction .status == "starting"
134
+
135
+
136
+ @pytest .mark .vcr ("models-search.yaml" )
137
+ @pytest .mark .asyncio
138
+ @pytest .mark .parametrize ("async_flag" , [True , False ])
139
+ async def test_models_search (async_flag ):
140
+ query = "llama"
141
+
142
+ if async_flag :
143
+ page = await replicate .models .async_search (query )
144
+ else :
145
+ page = replicate .models .search (query )
146
+
147
+ assert isinstance (page , Page )
148
+ assert len (page .results ) > 0
149
+
150
+ for model in page .results :
151
+ assert isinstance (model , Model )
152
+ assert model .id is not None
153
+ assert model .owner is not None
154
+ assert model .name is not None
155
+
156
+ assert any ("meta" in model .name .lower () for model in page .results )
You can’t perform that action at this time.
0 commit comments