Skip to content

Lint Python sources with ruff #85

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ jobs:
cache: "pip"
- name: Install dependencies
run: python -m pip install -r requirements.txt -r requirements-dev.txt .
- name: Lint
run: python -m ruff .
- name: Test
run: python -m pytest

Expand Down
22 changes: 21 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,31 @@ license = { file = "LICENSE" }
authors = [{ name = "Replicate, Inc." }]
requires-python = ">=3.8"
dependencies = ["packaging", "pydantic>1", "requests>2"]
optional-dependencies = { dev = ["black", "pytest", "responses"] }
optional-dependencies = { dev = ["black", "pytest", "responses", "ruff"] }

[project.urls]
homepage = "https://replicate.com"
repository = "https://github.com/replicate/replicate-python"

[tool.pytest.ini_options]
testpaths = "tests/"

[tool.ruff]
select = [
"E", # pycodestyle error
"F", # Pyflakes
"I", # isort
"W", # pycodestyle warning
"UP", # pyupgrade
"S", # flake8-bandit
"BLE", # flake8-blind-except
"FBT", # flake8-boolean-trap
"B", # flake8-bugbear
]
ignore = [
"E501", # Line too long
"S113", # Probable use of requests call without timeout
]

[tool.ruff.per-file-ignores]
"tests/*" = ["S101", "S106"]
2 changes: 1 addition & 1 deletion replicate/base_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import ForwardRef
import pydantic

import pydantic

Client = ForwardRef("Client")
Collection = ForwardRef("Collection")
Expand Down
2 changes: 1 addition & 1 deletion replicate/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@ def prepare_model(self, attrs):
model._collection = self
return model
else:
raise Exception("Can't create %s from %s" % (self.model.__name__, attrs))
raise Exception(f"Can't create {self.model.__name__} from {attrs}")
2 changes: 1 addition & 1 deletion replicate/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def upload_file(fh: io.IOBase, output_file_prefix: str = None) -> str:
if output_file_prefix is not None:
name = getattr(fh, "name", "output")
url = output_file_prefix + os.path.basename(name)
resp = requests.put(url, files={"file": fh})
resp = requests.put(url, files={"file": fh}, timeout=None)
resp.raise_for_status()
return url

Expand Down
1 change: 0 additions & 1 deletion replicate/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from types import GeneratorType
from typing import Any, Callable


try:
import numpy as np # type: ignore

Expand Down
4 changes: 1 addition & 3 deletions replicate/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import List

from replicate.base_model import BaseModel
from replicate.collection import Collection
from replicate.exceptions import ReplicateException
Expand All @@ -12,7 +10,7 @@ class Model(BaseModel):

def predict(self, *args, **kwargs):
raise ReplicateException(
f"The `model.predict()` method has been removed, because it's unstable: if a new version of the model you're using is pushed and its API has changed, your code may break. Use `version.predict()` instead. See https://github.com/replicate/replicate-python#readme"
"The `model.predict()` method has been removed, because it's unstable: if a new version of the model you're using is pushed and its API has changed, your code may break. Use `version.predict()` instead. See https://github.com/replicate/replicate-python#readme"
)

@property
Expand Down
4 changes: 2 additions & 2 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from replicate.base_model import BaseModel
from replicate.collection import Collection
from replicate.exceptions import ModelError, ReplicateException
from replicate.exceptions import ModelError
from replicate.files import upload_file
from replicate.json import encode_json
from replicate.version import Version
Expand Down Expand Up @@ -92,7 +92,7 @@ def get(self, id: str) -> Prediction:
return self.prepare_model(obj)

def list(self) -> List[Prediction]:
resp = self._client._request("GET", f"/v1/predictions")
resp = self._client._request("GET", "/v1/predictions")
# TODO: paginate
predictions = resp.json()["results"]
for prediction in predictions:
Expand Down
8 changes: 3 additions & 5 deletions replicate/training.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import re
import time
from typing import Any, Dict, Iterator, List, Optional
from typing import Any, Dict, List, Optional

from replicate.base_model import BaseModel
from replicate.collection import Collection
from replicate.exceptions import ModelError, ReplicateException
from replicate.exceptions import ReplicateException
from replicate.files import upload_file
from replicate.json import encode_json
from replicate.version import Version


class Training(BaseModel):
Expand Down Expand Up @@ -55,7 +53,7 @@ def create(
)
if not match:
raise ReplicateException(
f"version must be in format username/model_name:version_id"
"version must be in format username/model_name:version_id"
)
username = match.group("username")
model_name = match.group("model_name")
Expand Down
1 change: 1 addition & 0 deletions replicate/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def predict(self, **kwargs) -> Union[Any, Iterator[Any]]:
warnings.warn(
"version.predict() is deprecated. Use replicate.run() instead. It will be removed before version 1.0.",
DeprecationWarning,
stacklevel=1,
)

prediction = self._client.predictions.create(version=self, input=kwargs)
Expand Down
2 changes: 2 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ requests==2.28.2
# responses
responses==0.23.1
# via replicate (pyproject.toml)
ruff==0.0.261
# via replicate (pyproject.toml)
types-pyyaml==6.0.12.9
# via responses
typing-extensions==4.5.0
Expand Down
8 changes: 3 additions & 5 deletions tests/test_prediction.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import responses
from responses import matchers

import replicate

from .factories import create_client, create_version


Expand Down Expand Up @@ -41,7 +39,7 @@ def test_create_works_with_webhooks():
},
)

prediction = client.predictions.create(
client.predictions.create(
version=version,
input={"text": "world"},
webhook="https://example.com/webhook",
Expand Down Expand Up @@ -156,8 +154,8 @@ def test_async_timings():
)

assert prediction.created_at == "2022-04-26T20:00:40.658234Z"
assert prediction.completed_at == None
assert prediction.output == None
assert prediction.completed_at is None
assert prediction.output is None
prediction.wait()
assert prediction.created_at == "2022-04-26T20:00:40.658234Z"
assert prediction.completed_at == "2022-04-26T20:02:27.648305Z"
Expand Down
3 changes: 2 additions & 1 deletion tests/test_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import pytest
import responses
from replicate.exceptions import ModelError
from responses import matchers

from replicate.exceptions import ModelError

from .factories import (
create_version,
create_version_with_iterator_output,
Expand Down