Skip to content

Commit db27f92

Browse files
committed
Set default client poll_interval in test_run
Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent 430ede7 commit db27f92

File tree

5 files changed

+147
-4
lines changed

5 files changed

+147
-4
lines changed

replicate/model.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Dict, List, Optional, TypedDict, Union
22

3-
from typing_extensions import NotRequired, Unpack, deprecated
3+
from typing_extensions import NotRequired, Unpack, deprecated, overload
44

55
from replicate.base_model import BaseModel
66
from replicate.collection import Collection
@@ -118,7 +118,7 @@ class CreateParams(TypedDict): # pylint: disable=too-many-ancestors
118118
name: str
119119
visibility: str
120120
hardware: str
121-
description: str
121+
description: NotRequired[str]
122122
github_url: NotRequired[str]
123123
paper_url: NotRequired[str]
124124
license_url: NotRequired[str]
@@ -152,6 +152,38 @@ def get(self, key: str) -> Model:
152152
resp = self._client._request("GET", f"/v1/models/{key}")
153153
return self.prepare_model(resp.json())
154154

155+
@overload
156+
def create( # pylint: disable=arguments-differ disable=too-many-arguments
157+
self,
158+
owner: str,
159+
name: str,
160+
visibility: str,
161+
hardware: str,
162+
*,
163+
description: Optional[str] = None,
164+
github_url: Optional[str] = None,
165+
paper_url: Optional[str] = None,
166+
license_url: Optional[str] = None,
167+
cover_image_url: Optional[str] = None,
168+
) -> Model:
169+
...
170+
171+
@overload
172+
def create( # pylint: disable=arguments-differ disable=too-many-arguments
173+
self,
174+
*,
175+
owner: str,
176+
name: str,
177+
visibility: str,
178+
hardware: str,
179+
description: Optional[str] = None,
180+
github_url: Optional[str] = None,
181+
paper_url: Optional[str] = None,
182+
license_url: Optional[str] = None,
183+
cover_image_url: Optional[str] = None,
184+
) -> Model:
185+
...
186+
155187
def create(
156188
self,
157189
*args,

tests/cassettes/models-create.yaml

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
interactions:
2+
- request:
3+
body: '{"owner": "test", "name": "python-example", "visibility": "private", "hardware":
4+
"cpu", "description": "An example model"}'
5+
headers:
6+
accept:
7+
- '*/*'
8+
accept-encoding:
9+
- gzip, deflate
10+
connection:
11+
- keep-alive
12+
content-length:
13+
- '123'
14+
content-type:
15+
- application/json
16+
host:
17+
- api.replicate.com
18+
user-agent:
19+
- replicate-python/0.15.6
20+
method: POST
21+
uri: https://api.replicate.com/v1/models
22+
response:
23+
content: '{"url": "https://replicate.com/test/python-example", "owner": "test",
24+
"name": "python-example", "description": "An example model", "visibility": "private",
25+
"github_url": null, "paper_url": null, "license_url": null, "run_count": 0,
26+
"cover_image_url": null, "default_example": null, "latest_version": null}'
27+
headers:
28+
CF-Cache-Status:
29+
- DYNAMIC
30+
CF-RAY:
31+
- 81ff2e098ec0eb5b-SEA
32+
Connection:
33+
- keep-alive
34+
Content-Length:
35+
- '307'
36+
Content-Type:
37+
- application/json
38+
Date:
39+
- Thu, 02 Nov 2023 20:38:12 GMT
40+
Server:
41+
- cloudflare
42+
Strict-Transport-Security:
43+
- max-age=15552000
44+
allow:
45+
- GET, POST, HEAD, OPTIONS
46+
content-security-policy-report-only:
47+
- 'font-src ''report-sample'' ''self'' data:; img-src ''report-sample'' ''self''
48+
data: https://replicate.delivery https://*.replicate.delivery https://*.githubusercontent.com
49+
https://github.com; script-src ''report-sample'' ''self'' https://cdn.rudderlabs.com/v1.1/rudder-analytics.min.js;
50+
style-src ''report-sample'' ''self'' ''unsafe-inline''; connect-src ''report-sample''
51+
''self'' https://replicate.delivery https://*.replicate.delivery https://*.rudderlabs.com
52+
https://*.rudderstack.com https://*.mux.com https://*.sentry.io; worker-src
53+
''none''; media-src ''report-sample'' ''self'' https://replicate.delivery
54+
https://*.replicate.delivery https://*.mux.com https://*.sentry.io; default-src
55+
''self''; report-uri'
56+
cross-origin-opener-policy:
57+
- same-origin
58+
nel:
59+
- '{"report_to":"heroku-nel","max_age":3600,"success_fraction":0.005,"failure_fraction":0.05,"response_headers":["Via"]}'
60+
ratelimit-remaining:
61+
- '2999'
62+
ratelimit-reset:
63+
- '1'
64+
referrer-policy:
65+
- same-origin
66+
report-to:
67+
- '{"group":"heroku-nel","max_age":3600,"endpoints":[{"url":"https://nel.heroku.com/reports?ts=1698957492&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=m%2Fs583uNWdN4J4bm1G3JZoilUVMbh89egg%2FAEcTPZm4%3D"}]}'
68+
reporting-endpoints:
69+
- heroku-nel=https://nel.heroku.com/reports?ts=1698957492&sid=1b10b0ff-8a76-4548-befa-353fc6c6c045&s=m%2Fs583uNWdN4J4bm1G3JZoilUVMbh89egg%2FAEcTPZm4%3D
70+
vary:
71+
- Cookie, origin
72+
via:
73+
- 1.1 vegur, 1.1 google
74+
x-content-type-options:
75+
- nosniff
76+
x-frame-options:
77+
- DENY
78+
http_version: HTTP/1.1
79+
status_code: 201
80+
version: 1

tests/test_hardware.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import httpx
21
import pytest
3-
import respx
42

53
import replicate
64

tests/test_model.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,34 @@ async def test_models_list(mock_replicate_api_token):
2727
assert models[0].owner is not None
2828
assert models[0].name is not None
2929
assert models[0].visibility == "public"
30+
31+
32+
@pytest.mark.vcr("models-create.yaml")
33+
@pytest.mark.asyncio
34+
async def test_models_create(mock_replicate_api_token):
35+
model = replicate.models.create(
36+
owner="test",
37+
name="python-example",
38+
visibility="private",
39+
hardware="cpu",
40+
description="An example model",
41+
)
42+
43+
assert model.owner == "test"
44+
assert model.name == "python-example"
45+
assert model.visibility == "private"
46+
47+
48+
@pytest.mark.vcr("models-create.yaml")
49+
@pytest.mark.asyncio
50+
async def test_models_create_with_positional_arguments(mock_replicate_api_token):
51+
model = replicate.models.create(
52+
"test",
53+
"python-example",
54+
visibility="private",
55+
hardware="cpu",
56+
)
57+
58+
assert model.owner == "test"
59+
assert model.name == "python-example"
60+
assert model.visibility == "private"

tests/test_run.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
@pytest.mark.vcr("run.yaml")
1111
@pytest.mark.asyncio
1212
async def test_run(mock_replicate_api_token):
13+
replicate.default_client.poll_interval = 0.0
14+
1315
version = "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5"
1416

1517
input = {

0 commit comments

Comments
 (0)