Skip to content

Type check Python source code with ruff and mypy #87

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,15 @@ jobs:
python-version: ${{ matrix.python-version }}
cache: "pip"
- name: Install dependencies
run: python -m pip install -r requirements.txt -r requirements-dev.txt .
run: |
python -m pip install -r requirements.txt -r requirements-dev.txt .
yes | python -m mypy --install-types replicate || true

- name: Lint
run: python -m ruff .
run: |
python -m mypy replicate
python -m ruff .
python -m black --check .
- name: Test
run: python -m pytest

Expand Down
24 changes: 20 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@ license = { file = "LICENSE" }
authors = [{ name = "Replicate, Inc." }]
requires-python = ">=3.8"
dependencies = ["packaging", "pydantic>1", "requests>2"]
optional-dependencies = { dev = ["black", "pytest", "responses", "ruff"] }
optional-dependencies = { dev = [
"black",
"mypy",
"pytest",
"responses",
"ruff",
] }

[project.urls]
homepage = "https://replicate.com"
Expand All @@ -31,11 +37,21 @@ select = [
"BLE", # flake8-blind-except
"FBT", # flake8-boolean-trap
"B", # flake8-bugbear
"ANN", # flake8-annotations
]
ignore = [
"E501", # Line too long
"S113", # Probable use of requests call without timeout
"E501", # Line too long
"S113", # Probable use of requests call without timeout
"ANN001", # Missing type annotation for function argument
"ANN002", # Missing type annotation for `*args`
"ANN003", # Missing type annotation for `**kwargs`
"ANN101", # Missing type annotation for self in method
"ANN102", # Missing type annotation for cls in classmethod
]

[tool.ruff.per-file-ignores]
"tests/*" = ["S101", "S106"]
"tests/*" = [
"S101", # Use of assert
"S106", # Possible use of hard-coded password function arguments
"ANN201", # Missing return type annotation for public function
]
17 changes: 10 additions & 7 deletions replicate/base_model.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
from typing import ForwardRef
from typing import TYPE_CHECKING

import pydantic
if TYPE_CHECKING:
from replicate.client import Client
from replicate.collection import Collection

Client = ForwardRef("Client")
Collection = ForwardRef("Collection")
import pydantic


class BaseModel(pydantic.BaseModel):
"""
A base class for representing a single object on the server.
"""

_client: Client = pydantic.PrivateAttr()
_collection: Collection = pydantic.PrivateAttr()
id: str

_client: "Client" = pydantic.PrivateAttr()
_collection: "Collection" = pydantic.PrivateAttr()

def reload(self):
def reload(self) -> None:
"""
Load this object from the server again.
"""
Expand Down
16 changes: 8 additions & 8 deletions replicate/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import re
from json import JSONDecodeError
from typing import Any, Iterator, Union
from typing import Any, Dict, Iterator, Optional, Union

import requests
from requests.adapters import HTTPAdapter, Retry
Expand All @@ -14,7 +14,7 @@


class Client:
def __init__(self, api_token=None) -> None:
def __init__(self, api_token: Optional[str] = None) -> None:
super().__init__()
# Client is instantiated at import time, so do as little as possible.
# This includes resolving environment variables -- they might be set programmatically.
Expand All @@ -30,7 +30,7 @@ def __init__(self, api_token=None) -> None:
total=5,
backoff_factor=2,
# Only retry 500s on GET so we don't unintionally mutute data
method_whitelist=["GET"],
allowed_methods=["GET"],
# https://support.cloudflare.com/hc/en-us/articles/115003011431-Troubleshooting-Cloudflare-5XX-errors
status_forcelist=[
429,
Expand All @@ -54,14 +54,14 @@ def __init__(self, api_token=None) -> None:
write_retries = Retry(
total=5,
backoff_factor=2,
method_whitelist=["POST", "PUT"],
allowed_methods=["POST", "PUT"],
# Only retry POST/PUT requests on rate limits, so we don't unintionally mutute data
status_forcelist=[429],
)
self.write_session.mount("http://", HTTPAdapter(max_retries=write_retries))
self.write_session.mount("https://", HTTPAdapter(max_retries=write_retries))

def _request(self, method: str, path: str, **kwargs):
def _request(self, method: str, path: str, **kwargs) -> requests.Response:
# from requests.Session
if method in ["GET", "OPTIONS"]:
kwargs.setdefault("allow_redirects", True)
Expand All @@ -81,13 +81,13 @@ def _request(self, method: str, path: str, **kwargs):
raise ReplicateError(f"HTTP error: {resp.status_code, resp.reason}")
return resp

def _headers(self):
def _headers(self) -> Dict[str, str]:
return {
"Authorization": f"Token {self._api_token()}",
"User-Agent": f"replicate-python@{__version__}",
}

def _api_token(self):
def _api_token(self) -> str:
token = self.api_token
# Evaluate lazily in case environment variable is set with dotenv, or something
if token is None:
Expand All @@ -112,7 +112,7 @@ def predictions(self) -> PredictionCollection:
def trainings(self) -> TrainingCollection:
return TrainingCollection(client=self)

def run(self, model_version, **kwargs) -> Union[Any, Iterator[Any]]:
def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]:
"""
Run a model in the format owner/name:version.
"""
Expand Down
44 changes: 30 additions & 14 deletions replicate/collection.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,54 @@
import abc
from typing import TYPE_CHECKING, Dict, Generic, List, TypeVar, Union, cast

if TYPE_CHECKING:
from replicate.client import Client

from replicate.base_model import BaseModel

Model = TypeVar("Model", bound=BaseModel)


class Collection:
class Collection(abc.ABC, Generic[Model]):
"""
A base class for representing all objects of a particular type on the
server.
"""

model: BaseModel = None

def __init__(self, client=None):
def __init__(self, client: "Client") -> None:
self._client = client

def list(self):
raise NotImplementedError
@abc.abstractproperty
def model(self) -> Model:
pass

@abc.abstractmethod
def list(self) -> List[Model]:
pass

def get(self, key):
raise NotImplementedError
@abc.abstractmethod
def get(self, key: str) -> Model:
pass

def create(self, attrs=None):
raise NotImplementedError
@abc.abstractmethod
def create(self, **kwargs) -> Model:
pass

def prepare_model(self, attrs):
def prepare_model(self, attrs: Union[Model, Dict]) -> Model:
"""
Create a model from a set of attributes.
"""
if isinstance(attrs, BaseModel):
attrs._client = self._client
attrs._collection = self
return attrs
elif isinstance(attrs, dict):
return cast(Model, attrs)
elif (
isinstance(attrs, dict) and self.model is not None and callable(self.model)
):
model = self.model(**attrs)
model._client = self._client
model._collection = self
return model
else:
raise Exception(f"Can't create {self.model.__name__} from {attrs}")
name = self.model.__name__ if hasattr(self.model, "__name__") else "model"
raise Exception(f"Can't create {name} from {attrs}")
3 changes: 2 additions & 1 deletion replicate/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import io
import mimetypes
import os
from typing import Optional

import requests


def upload_file(fh: io.IOBase, output_file_prefix: str = None) -> str:
def upload_file(fh: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
"""
Lifted straight from cog.files
"""
Expand Down
4 changes: 3 additions & 1 deletion replicate/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
has_numpy = False


def encode_json(obj: Any, upload_file: Callable[[io.IOBase], str]) -> Any:
def encode_json(
obj: Any, upload_file: Callable[[io.IOBase], str] # noqa: ANN401
) -> Any: # noqa: ANN401
"""
Returns a JSON-compatible version of the object. Effectively the same thing as cog.json.encode_json.
"""
Expand Down
19 changes: 17 additions & 2 deletions replicate/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Dict, List, Union

from replicate.base_model import BaseModel
from replicate.collection import Collection
from replicate.exceptions import ReplicateException
Expand All @@ -8,21 +10,34 @@ class Model(BaseModel):
username: str
name: str

def predict(self, *args, **kwargs):
def predict(self, *args, **kwargs) -> None:
raise ReplicateException(
"The `model.predict()` method has been removed, because it's unstable: if a new version of the model you're using is pushed and its API has changed, your code may break. Use `version.predict()` instead. See https://github.com/replicate/replicate-python#readme"
)

@property
def versions(self):
def versions(self) -> VersionCollection:
return VersionCollection(client=self._client, model=self)


class ModelCollection(Collection):
model = Model

def list(self) -> List[Model]:
raise NotImplementedError()

def get(self, name: str) -> Model:
# TODO: fetch model from server
# TODO: support permanent IDs
username, name = name.split("/")
return self.prepare_model({"username": username, "name": name})

def create(self, **kwargs) -> Model:
raise NotImplementedError()

def prepare_model(self, attrs: Union[Model, Dict]) -> Model:
if isinstance(attrs, BaseModel):
attrs.id = f"{attrs.username}/{attrs.name}"
elif isinstance(attrs, dict):
attrs["id"] = f"{attrs['username']}/{attrs['name']}"
return super().prepare_model(attrs)
39 changes: 20 additions & 19 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class Prediction(BaseModel):
created_at: Optional[str]
completed_at: Optional[str]

def wait(self):
def wait(self) -> None:
"""Wait for prediction to finish."""
while self.status not in ["succeeded", "failed", "canceled"]:
time.sleep(self._client.poll_interval)
Expand All @@ -47,21 +47,38 @@ def output_iterator(self) -> Iterator[Any]:
for output in new_output:
yield output

def cancel(self):
def cancel(self) -> None:
"""Cancel a currently running prediction"""
self._client._request("POST", f"/v1/predictions/{self.id}/cancel")


class PredictionCollection(Collection):
model = Prediction

def create(
def list(self) -> List[Prediction]:
resp = self._client._request("GET", "/v1/predictions")
# TODO: paginate
predictions = resp.json()["results"]
for prediction in predictions:
# HACK: resolve this? make it lazy somehow?
del prediction["version"]
return [self.prepare_model(obj) for obj in predictions]

def get(self, id: str) -> Prediction:
resp = self._client._request("GET", f"/v1/predictions/{id}")
obj = resp.json()
# HACK: resolve this? make it lazy somehow?
del obj["version"]
return self.prepare_model(obj)

def create( # type: ignore
self,
version: Version,
input: Dict[str, Any],
webhook: Optional[str] = None,
webhook_completed: Optional[str] = None,
webhook_events_filter: Optional[List[str]] = None,
**kwargs,
) -> Prediction:
input = encode_json(input, upload_file=upload_file)
body = {
Expand All @@ -83,19 +100,3 @@ def create(
obj = resp.json()
obj["version"] = version
return self.prepare_model(obj)

def get(self, id: str) -> Prediction:
resp = self._client._request("GET", f"/v1/predictions/{id}")
obj = resp.json()
# HACK: resolve this? make it lazy somehow?
del obj["version"]
return self.prepare_model(obj)

def list(self) -> List[Prediction]:
resp = self._client._request("GET", "/v1/predictions")
# TODO: paginate
predictions = resp.json()["results"]
for prediction in predictions:
# HACK: resolve this? make it lazy somehow?
del prediction["version"]
return [self.prepare_model(obj) for obj in predictions]
7 changes: 5 additions & 2 deletions replicate/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
# TODO: this code is shared with replicate's backend. Maybe we should put it in the Cog Python package as the source of truth?


def version_has_no_array_type(cog_version):
def version_has_no_array_type(cog_version: str) -> bool:
"""Iterators have x-cog-array-type=iterator in the schema from 0.3.9 onward"""
return version.parse(cog_version) < version.parse("0.3.9")


def make_schema_backwards_compatible(schema, version):
def make_schema_backwards_compatible(
schema: dict,
version: str,
) -> dict:
"""A place to add backwards compatibility logic for our openapi schema"""
# If the top-level output is an array, assume it is an iterator in old versions which didn't have an array type
if version_has_no_array_type(version):
Expand Down
Loading