Skip to content

Commit c892c78

Browse files
authored
Lint Python sources with ruff (#85)
* Add ruff as development dependency Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Add lint step to CI workflow Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Apply automatic fixes made by ruff Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Ignore rule E501 Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Explicitly select ruff rule sets Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Enable isort rules in ruff Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Apply automatic fixes made by ruff for isort violations Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Enable pyupgrade rules in ruff Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Apply automatic fix made by ruff for pyupgrade violation Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Enable flake8-bandit rules in ruff Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Enable flake8-boolean-trap rules in ruff Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Enable flake8-bugbear rules in ruff Signed-off-by: Mattt Zmuda <mattt@replicate.com> * Fix 'B028 No explicit `stacklevel` keyword argument found' Signed-off-by: Mattt Zmuda <mattt@replicate.com> --------- Signed-off-by: Mattt Zmuda <mattt@replicate.com>
1 parent 6094f5a commit c892c78

File tree

13 files changed

+40
-21
lines changed

13 files changed

+40
-21
lines changed

.github/workflows/ci.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ jobs:
2424
cache: "pip"
2525
- name: Install dependencies
2626
run: python -m pip install -r requirements.txt -r requirements-dev.txt .
27+
- name: Lint
28+
run: python -m ruff .
2729
- name: Test
2830
run: python -m pytest
2931

pyproject.toml

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,31 @@ license = { file = "LICENSE" }
1111
authors = [{ name = "Replicate, Inc." }]
1212
requires-python = ">=3.8"
1313
dependencies = ["packaging", "pydantic>1", "requests>2"]
14-
optional-dependencies = { dev = ["black", "pytest", "responses"] }
14+
optional-dependencies = { dev = ["black", "pytest", "responses", "ruff"] }
1515

1616
[project.urls]
1717
homepage = "https://replicate.com"
1818
repository = "https://github.com/replicate/replicate-python"
1919

2020
[tool.pytest.ini_options]
2121
testpaths = "tests/"
22+
23+
[tool.ruff]
24+
select = [
25+
"E", # pycodestyle error
26+
"F", # Pyflakes
27+
"I", # isort
28+
"W", # pycodestyle warning
29+
"UP", # pyupgrade
30+
"S", # flake8-bandit
31+
"BLE", # flake8-blind-except
32+
"FBT", # flake8-boolean-trap
33+
"B", # flake8-bugbear
34+
]
35+
ignore = [
36+
"E501", # Line too long
37+
"S113", # Probable use of requests call without timeout
38+
]
39+
40+
[tool.ruff.per-file-ignores]
41+
"tests/*" = ["S101", "S106"]

replicate/base_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import ForwardRef
2-
import pydantic
32

3+
import pydantic
44

55
Client = ForwardRef("Client")
66
Collection = ForwardRef("Collection")

replicate/collection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,4 @@ def prepare_model(self, attrs):
3535
model._collection = self
3636
return model
3737
else:
38-
raise Exception("Can't create %s from %s" % (self.model.__name__, attrs))
38+
raise Exception(f"Can't create {self.model.__name__} from {attrs}")

replicate/files.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def upload_file(fh: io.IOBase, output_file_prefix: str = None) -> str:
1515
if output_file_prefix is not None:
1616
name = getattr(fh, "name", "output")
1717
url = output_file_prefix + os.path.basename(name)
18-
resp = requests.put(url, files={"file": fh})
18+
resp = requests.put(url, files={"file": fh}, timeout=None)
1919
resp.raise_for_status()
2020
return url
2121

replicate/json.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from types import GeneratorType
44
from typing import Any, Callable
55

6-
76
try:
87
import numpy as np # type: ignore
98

replicate/model.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import List
2-
31
from replicate.base_model import BaseModel
42
from replicate.collection import Collection
53
from replicate.exceptions import ReplicateException
@@ -12,7 +10,7 @@ class Model(BaseModel):
1210

1311
def predict(self, *args, **kwargs):
1412
raise ReplicateException(
15-
f"The `model.predict()` method has been removed, because it's unstable: if a new version of the model you're using is pushed and its API has changed, your code may break. Use `version.predict()` instead. See https://github.com/replicate/replicate-python#readme"
13+
"The `model.predict()` method has been removed, because it's unstable: if a new version of the model you're using is pushed and its API has changed, your code may break. Use `version.predict()` instead. See https://github.com/replicate/replicate-python#readme"
1614
)
1715

1816
@property

replicate/prediction.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from replicate.base_model import BaseModel
55
from replicate.collection import Collection
6-
from replicate.exceptions import ModelError, ReplicateException
6+
from replicate.exceptions import ModelError
77
from replicate.files import upload_file
88
from replicate.json import encode_json
99
from replicate.version import Version
@@ -92,7 +92,7 @@ def get(self, id: str) -> Prediction:
9292
return self.prepare_model(obj)
9393

9494
def list(self) -> List[Prediction]:
95-
resp = self._client._request("GET", f"/v1/predictions")
95+
resp = self._client._request("GET", "/v1/predictions")
9696
# TODO: paginate
9797
predictions = resp.json()["results"]
9898
for prediction in predictions:

replicate/training.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
import re
2-
import time
3-
from typing import Any, Dict, Iterator, List, Optional
2+
from typing import Any, Dict, List, Optional
43

54
from replicate.base_model import BaseModel
65
from replicate.collection import Collection
7-
from replicate.exceptions import ModelError, ReplicateException
6+
from replicate.exceptions import ReplicateException
87
from replicate.files import upload_file
98
from replicate.json import encode_json
10-
from replicate.version import Version
119

1210

1311
class Training(BaseModel):
@@ -55,7 +53,7 @@ def create(
5553
)
5654
if not match:
5755
raise ReplicateException(
58-
f"version must be in format username/model_name:version_id"
56+
"version must be in format username/model_name:version_id"
5957
)
6058
username = match.group("username")
6159
model_name = match.group("model_name")

replicate/version.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def predict(self, **kwargs) -> Union[Any, Iterator[Any]]:
1818
warnings.warn(
1919
"version.predict() is deprecated. Use replicate.run() instead. It will be removed before version 1.0.",
2020
DeprecationWarning,
21+
stacklevel=1,
2122
)
2223

2324
prediction = self._client.predictions.create(version=self, input=kwargs)

requirements-dev.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ requests==2.28.2
4343
# responses
4444
responses==0.23.1
4545
# via replicate (pyproject.toml)
46+
ruff==0.0.261
47+
# via replicate (pyproject.toml)
4648
types-pyyaml==6.0.12.9
4749
# via responses
4850
typing-extensions==4.5.0

tests/test_prediction.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import responses
22
from responses import matchers
33

4-
import replicate
5-
64
from .factories import create_client, create_version
75

86

@@ -41,7 +39,7 @@ def test_create_works_with_webhooks():
4139
},
4240
)
4341

44-
prediction = client.predictions.create(
42+
client.predictions.create(
4543
version=version,
4644
input={"text": "world"},
4745
webhook="https://example.com/webhook",
@@ -156,8 +154,8 @@ def test_async_timings():
156154
)
157155

158156
assert prediction.created_at == "2022-04-26T20:00:40.658234Z"
159-
assert prediction.completed_at == None
160-
assert prediction.output == None
157+
assert prediction.completed_at is None
158+
assert prediction.output is None
161159
prediction.wait()
162160
assert prediction.created_at == "2022-04-26T20:00:40.658234Z"
163161
assert prediction.completed_at == "2022-04-26T20:02:27.648305Z"

tests/test_version.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
import pytest
44
import responses
5-
from replicate.exceptions import ModelError
65
from responses import matchers
76

7+
from replicate.exceptions import ModelError
8+
89
from .factories import (
910
create_version,
1011
create_version_with_iterator_output,

0 commit comments

Comments
 (0)