Skip to content

Commit afe0092

Browse files
authored
Type check Python source code with ruff and mypy (#87)
* Enable flake8-annotations check Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Fix annotations and type checking errors Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Ignore ANN001 in tests Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Fix remaining annotation warnings Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Add missing id property to BaseModel Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Fix type checking for collection model Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Add mypy as development dependency Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Ignore type checking for overloaded create methods Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Run black and mypy in lint step Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Install mypy types in CI workflow Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Use List and Union from typing Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Fix CI setup for mypy Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Fix pydantic validation error due to missing id field in model Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Fix urllib3 DeprecationWarning about using 'method_whitelist' with Retry Signed-off-by: Mattt Zmuda <mattt@replicate.com> * error: "dict" is not subscriptable, use "typing.Dict" instead Signed-off-by: Mattt Zmuda <mattt@replicate.com> --------- Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent c892c78 commit afe0092

File tree

13 files changed

+158
-77
lines changed

13 files changed

+158
-77
lines changed

.github/workflows/ci.yaml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,15 @@ jobs:
2323
python-version: ${{ matrix.python-version }}
2424
cache: "pip"
2525
- name: Install dependencies
26-
run: python -m pip install -r requirements.txt -r requirements-dev.txt .
26+
run: |
27+
python -m pip install -r requirements.txt -r requirements-dev.txt .
28+
yes | python -m mypy --install-types replicate || true
29+
2730
- name: Lint
28-
run: python -m ruff .
31+
run: |
32+
python -m mypy replicate
33+
python -m ruff .
34+
python -m black --check .
2935
- name: Test
3036
run: python -m pytest
3137

pyproject.toml

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,13 @@ license = { file = "LICENSE" }
1111
authors = [{ name = "Replicate, Inc." }]
1212
requires-python = ">=3.8"
1313
dependencies = ["packaging", "pydantic>1", "requests>2"]
14-
optional-dependencies = { dev = ["black", "pytest", "responses", "ruff"] }
14+
optional-dependencies = { dev = [
15+
"black",
16+
"mypy",
17+
"pytest",
18+
"responses",
19+
"ruff",
20+
] }
1521

1622
[project.urls]
1723
homepage = "https://replicate.com"
@@ -31,11 +37,21 @@ select = [
3137
"BLE", # flake8-blind-except
3238
"FBT", # flake8-boolean-trap
3339
"B", # flake8-bugbear
40+
"ANN", # flake8-annotations
3441
]
3542
ignore = [
36-
"E501", # Line too long
37-
"S113", # Probable use of requests call without timeout
43+
"E501", # Line too long
44+
"S113", # Probable use of requests call without timeout
45+
"ANN001", # Missing type annotation for function argument
46+
"ANN002", # Missing type annotation for `*args`
47+
"ANN003", # Missing type annotation for `**kwargs`
48+
"ANN101", # Missing type annotation for self in method
49+
"ANN102", # Missing type annotation for cls in classmethod
3850
]
3951

4052
[tool.ruff.per-file-ignores]
41-
"tests/*" = ["S101", "S106"]
53+
"tests/*" = [
54+
"S101", # Use of assert
55+
"S106", # Possible use of hard-coded password function arguments
56+
"ANN201", # Missing return type annotation for public function
57+
]

replicate/base_model.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
1-
from typing import ForwardRef
1+
from typing import TYPE_CHECKING
22

3-
import pydantic
3+
if TYPE_CHECKING:
4+
from replicate.client import Client
5+
from replicate.collection import Collection
46

5-
Client = ForwardRef("Client")
6-
Collection = ForwardRef("Collection")
7+
import pydantic
78

89

910
class BaseModel(pydantic.BaseModel):
1011
"""
1112
A base class for representing a single object on the server.
1213
"""
1314

14-
_client: Client = pydantic.PrivateAttr()
15-
_collection: Collection = pydantic.PrivateAttr()
15+
id: str
16+
17+
_client: "Client" = pydantic.PrivateAttr()
18+
_collection: "Collection" = pydantic.PrivateAttr()
1619

17-
def reload(self):
20+
def reload(self) -> None:
1821
"""
1922
Load this object from the server again.
2023
"""

replicate/client.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import re
33
from json import JSONDecodeError
4-
from typing import Any, Iterator, Union
4+
from typing import Any, Dict, Iterator, Optional, Union
55

66
import requests
77
from requests.adapters import HTTPAdapter, Retry
@@ -14,7 +14,7 @@
1414

1515

1616
class Client:
17-
def __init__(self, api_token=None) -> None:
17+
def __init__(self, api_token: Optional[str] = None) -> None:
1818
super().__init__()
1919
# Client is instantiated at import time, so do as little as possible.
2020
# This includes resolving environment variables -- they might be set programmatically.
@@ -30,7 +30,7 @@ def __init__(self, api_token=None) -> None:
3030
total=5,
3131
backoff_factor=2,
3232
# Only retry 500s on GET so we don't unintionally mutute data
33-
method_whitelist=["GET"],
33+
allowed_methods=["GET"],
3434
# https://support.cloudflare.com/hc/en-us/articles/115003011431-Troubleshooting-Cloudflare-5XX-errors
3535
status_forcelist=[
3636
429,
@@ -54,14 +54,14 @@ def __init__(self, api_token=None) -> None:
5454
write_retries = Retry(
5555
total=5,
5656
backoff_factor=2,
57-
method_whitelist=["POST", "PUT"],
57+
allowed_methods=["POST", "PUT"],
5858
# Only retry POST/PUT requests on rate limits, so we don't unintionally mutute data
5959
status_forcelist=[429],
6060
)
6161
self.write_session.mount("http://", HTTPAdapter(max_retries=write_retries))
6262
self.write_session.mount("https://", HTTPAdapter(max_retries=write_retries))
6363

64-
def _request(self, method: str, path: str, **kwargs):
64+
def _request(self, method: str, path: str, **kwargs) -> requests.Response:
6565
# from requests.Session
6666
if method in ["GET", "OPTIONS"]:
6767
kwargs.setdefault("allow_redirects", True)
@@ -81,13 +81,13 @@ def _request(self, method: str, path: str, **kwargs):
8181
raise ReplicateError(f"HTTP error: {resp.status_code, resp.reason}")
8282
return resp
8383

84-
def _headers(self):
84+
def _headers(self) -> Dict[str, str]:
8585
return {
8686
"Authorization": f"Token {self._api_token()}",
8787
"User-Agent": f"replicate-python@{__version__}",
8888
}
8989

90-
def _api_token(self):
90+
def _api_token(self) -> str:
9191
token = self.api_token
9292
# Evaluate lazily in case environment variable is set with dotenv, or something
9393
if token is None:
@@ -112,7 +112,7 @@ def predictions(self) -> PredictionCollection:
112112
def trainings(self) -> TrainingCollection:
113113
return TrainingCollection(client=self)
114114

115-
def run(self, model_version, **kwargs) -> Union[Any, Iterator[Any]]:
115+
def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]:
116116
"""
117117
Run a model in the format owner/name:version.
118118
"""

replicate/collection.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,54 @@
1+
import abc
2+
from typing import TYPE_CHECKING, Dict, Generic, List, TypeVar, Union, cast
3+
4+
if TYPE_CHECKING:
5+
from replicate.client import Client
6+
17
from replicate.base_model import BaseModel
28

9+
Model = TypeVar("Model", bound=BaseModel)
10+
311

4-
class Collection:
12+
class Collection(abc.ABC, Generic[Model]):
513
"""
614
A base class for representing all objects of a particular type on the
715
server.
816
"""
917

10-
model: BaseModel = None
11-
12-
def __init__(self, client=None):
18+
def __init__(self, client: "Client") -> None:
1319
self._client = client
1420

15-
def list(self):
16-
raise NotImplementedError
21+
@abc.abstractproperty
22+
def model(self) -> Model:
23+
pass
24+
25+
@abc.abstractmethod
26+
def list(self) -> List[Model]:
27+
pass
1728

18-
def get(self, key):
19-
raise NotImplementedError
29+
@abc.abstractmethod
30+
def get(self, key: str) -> Model:
31+
pass
2032

21-
def create(self, attrs=None):
22-
raise NotImplementedError
33+
@abc.abstractmethod
34+
def create(self, **kwargs) -> Model:
35+
pass
2336

24-
def prepare_model(self, attrs):
37+
def prepare_model(self, attrs: Union[Model, Dict]) -> Model:
2538
"""
2639
Create a model from a set of attributes.
2740
"""
2841
if isinstance(attrs, BaseModel):
2942
attrs._client = self._client
3043
attrs._collection = self
31-
return attrs
32-
elif isinstance(attrs, dict):
44+
return cast(Model, attrs)
45+
elif (
46+
isinstance(attrs, dict) and self.model is not None and callable(self.model)
47+
):
3348
model = self.model(**attrs)
3449
model._client = self._client
3550
model._collection = self
3651
return model
3752
else:
38-
raise Exception(f"Can't create {self.model.__name__} from {attrs}")
53+
name = self.model.__name__ if hasattr(self.model, "__name__") else "model"
54+
raise Exception(f"Can't create {name} from {attrs}")

replicate/files.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
import io
33
import mimetypes
44
import os
5+
from typing import Optional
56

67
import requests
78

89

9-
def upload_file(fh: io.IOBase, output_file_prefix: str = None) -> str:
10+
def upload_file(fh: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
1011
"""
1112
Lifted straight from cog.files
1213
"""

replicate/json.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
has_numpy = False
1212

1313

14-
def encode_json(obj: Any, upload_file: Callable[[io.IOBase], str]) -> Any:
14+
def encode_json(
15+
obj: Any, upload_file: Callable[[io.IOBase], str] # noqa: ANN401
16+
) -> Any: # noqa: ANN401
1517
"""
1618
Returns a JSON-compatible version of the object. Effectively the same thing as cog.json.encode_json.
1719
"""

replicate/model.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Dict, List, Union
2+
13
from replicate.base_model import BaseModel
24
from replicate.collection import Collection
35
from replicate.exceptions import ReplicateException
@@ -8,21 +10,34 @@ class Model(BaseModel):
810
username: str
911
name: str
1012

11-
def predict(self, *args, **kwargs):
13+
def predict(self, *args, **kwargs) -> None:
1214
raise ReplicateException(
1315
"The `model.predict()` method has been removed, because it's unstable: if a new version of the model you're using is pushed and its API has changed, your code may break. Use `version.predict()` instead. See https://github.com/replicate/replicate-python#readme"
1416
)
1517

1618
@property
17-
def versions(self):
19+
def versions(self) -> VersionCollection:
1820
return VersionCollection(client=self._client, model=self)
1921

2022

2123
class ModelCollection(Collection):
2224
model = Model
2325

26+
def list(self) -> List[Model]:
27+
raise NotImplementedError()
28+
2429
def get(self, name: str) -> Model:
2530
# TODO: fetch model from server
2631
# TODO: support permanent IDs
2732
username, name = name.split("/")
2833
return self.prepare_model({"username": username, "name": name})
34+
35+
def create(self, **kwargs) -> Model:
36+
raise NotImplementedError()
37+
38+
def prepare_model(self, attrs: Union[Model, Dict]) -> Model:
39+
if isinstance(attrs, BaseModel):
40+
attrs.id = f"{attrs.username}/{attrs.name}"
41+
elif isinstance(attrs, dict):
42+
attrs["id"] = f"{attrs['username']}/{attrs['name']}"
43+
return super().prepare_model(attrs)

replicate/prediction.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class Prediction(BaseModel):
2121
created_at: Optional[str]
2222
completed_at: Optional[str]
2323

24-
def wait(self):
24+
def wait(self) -> None:
2525
"""Wait for prediction to finish."""
2626
while self.status not in ["succeeded", "failed", "canceled"]:
2727
time.sleep(self._client.poll_interval)
@@ -47,21 +47,38 @@ def output_iterator(self) -> Iterator[Any]:
4747
for output in new_output:
4848
yield output
4949

50-
def cancel(self):
50+
def cancel(self) -> None:
5151
"""Cancel a currently running prediction"""
5252
self._client._request("POST", f"/v1/predictions/{self.id}/cancel")
5353

5454

5555
class PredictionCollection(Collection):
5656
model = Prediction
5757

58-
def create(
58+
def list(self) -> List[Prediction]:
59+
resp = self._client._request("GET", "/v1/predictions")
60+
# TODO: paginate
61+
predictions = resp.json()["results"]
62+
for prediction in predictions:
63+
# HACK: resolve this? make it lazy somehow?
64+
del prediction["version"]
65+
return [self.prepare_model(obj) for obj in predictions]
66+
67+
def get(self, id: str) -> Prediction:
68+
resp = self._client._request("GET", f"/v1/predictions/{id}")
69+
obj = resp.json()
70+
# HACK: resolve this? make it lazy somehow?
71+
del obj["version"]
72+
return self.prepare_model(obj)
73+
74+
def create( # type: ignore
5975
self,
6076
version: Version,
6177
input: Dict[str, Any],
6278
webhook: Optional[str] = None,
6379
webhook_completed: Optional[str] = None,
6480
webhook_events_filter: Optional[List[str]] = None,
81+
**kwargs,
6582
) -> Prediction:
6683
input = encode_json(input, upload_file=upload_file)
6784
body = {
@@ -83,19 +100,3 @@ def create(
83100
obj = resp.json()
84101
obj["version"] = version
85102
return self.prepare_model(obj)
86-
87-
def get(self, id: str) -> Prediction:
88-
resp = self._client._request("GET", f"/v1/predictions/{id}")
89-
obj = resp.json()
90-
# HACK: resolve this? make it lazy somehow?
91-
del obj["version"]
92-
return self.prepare_model(obj)
93-
94-
def list(self) -> List[Prediction]:
95-
resp = self._client._request("GET", "/v1/predictions")
96-
# TODO: paginate
97-
predictions = resp.json()["results"]
98-
for prediction in predictions:
99-
# HACK: resolve this? make it lazy somehow?
100-
del prediction["version"]
101-
return [self.prepare_model(obj) for obj in predictions]

replicate/schema.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33
# TODO: this code is shared with replicate's backend. Maybe we should put it in the Cog Python package as the source of truth?
44

55

6-
def version_has_no_array_type(cog_version):
6+
def version_has_no_array_type(cog_version: str) -> bool:
77
"""Iterators have x-cog-array-type=iterator in the schema from 0.3.9 onward"""
88
return version.parse(cog_version) < version.parse("0.3.9")
99

1010

11-
def make_schema_backwards_compatible(schema, version):
11+
def make_schema_backwards_compatible(
12+
schema: dict,
13+
version: str,
14+
) -> dict:
1215
"""A place to add backwards compatibility logic for our openapi schema"""
1316
# If the top-level output is an array, assume it is an iterator in old versions which didn't have an array type
1417
if version_has_no_array_type(version):

0 commit comments

Comments
 (0)