Skip to content

Commit

Permalink
Enable mypy in CI
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Sep 18, 2021
1 parent 05677bb commit 62ded0f
Show file tree
Hide file tree
Showing 56 changed files with 207 additions and 174 deletions.
2 changes: 1 addition & 1 deletion .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
- [ ] Closes #xxxx
- [ ] Tests added / passed
- [ ] Passes `black distributed` / `flake8 distributed` / `isort distributed`
- [ ] Passes `pre-commit run --all-files`
18 changes: 16 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
repos:
- repo: https://github.com/pycqa/isort
rev: 5.8.0
rev: 5.9.3
hooks:
- id: isort
language_version: python3
- repo: https://github.com/psf/black
rev: 21.5b1
rev: 21.9b0
hooks:
- id: black
language_version: python3
Expand All @@ -17,3 +17,17 @@ repos:
hooks:
- id: flake8
language_version: python3
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.910
hooks:
- id: mypy
additional_dependencies: [
# Type stubs
types-docutils,
types-requests,
types-paramiko,
types-pkg_resources,
types-PyYAML,
types-setuptools,
types-psutil,
]
2 changes: 1 addition & 1 deletion distributed/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from . import config # isort:skip; load distributed configuration first
from . import widgets # isort:skip; load distributed widgets second
import dask
from dask.config import config
from dask.config import config # type: ignore
from dask.utils import import_required

from ._version import get_versions
Expand Down
15 changes: 7 additions & 8 deletions distributed/_concurrent_futures_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,17 @@

"""Implements ThreadPoolExecutor."""

from __future__ import annotations

__author__ = "Brian Quinlan (brian@sweetapp.com)"

import atexit
import itertools
from concurrent.futures import _base

try:
import queue
except ImportError:
import Queue as queue

import os
import queue
import threading
import weakref
from concurrent.futures import _base

# Workers are created as daemon threads. This is done to allow the interpreter
# to exit when there are still idle threads in a ThreadPoolExecutor's thread
Expand All @@ -34,7 +31,9 @@
# workers to exit when their work queues are empty and then waits until the
# threads finish.

_threads_queues = weakref.WeakKeyDictionary()
_threads_queues: weakref.WeakKeyDictionary[
threading.Thread, queue.Queue
] = weakref.WeakKeyDictionary()
_shutdown = False


Expand Down
10 changes: 2 additions & 8 deletions distributed/_ipython_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,7 @@

import atexit
import os

try:
import queue
except ImportError:
# Python 2
import Queue as queue

import queue
import sys
from subprocess import Popen
from threading import Event, Thread
Expand Down Expand Up @@ -135,7 +129,7 @@ def remote_magic(line, cell=None):


# cache clients for re-use in remote magic
remote_magic._clients = {}
remote_magic._clients = {} # type: ignore


def register_remote_magic(magic_name="remote"):
Expand Down
2 changes: 1 addition & 1 deletion distributed/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class NotThisMethod(Exception):
"""Exception raised if a method is not valid for the current scenario."""


LONG_VERSION_PY = {}
LONG_VERSION_PY: dict = {}
HANDLERS = {}


Expand Down
24 changes: 12 additions & 12 deletions distributed/active_memory_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
from collections import defaultdict
from collections.abc import Generator
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

from tornado.ioloop import PeriodicCallback

Expand Down Expand Up @@ -46,10 +46,10 @@ def __init__(
scheduler: SchedulerState,
# The following parameters are exposed so that one may create, run, and throw
# away on the fly a specialized manager, separate from the main one.
policies: Optional[set[ActiveMemoryManagerPolicy]] = None,
policies: set[ActiveMemoryManagerPolicy] | None = None,
register: bool = True,
start: Optional[bool] = None,
interval: Optional[float] = None,
start: bool | None = None,
interval: float | None = None,
):
self.scheduler = scheduler
self.policies = set()
Expand Down Expand Up @@ -157,9 +157,9 @@ def _run_policies(self) -> None:
"""Sequentially run ActiveMemoryManagerPolicy.run() for all registered policies,
obtain replicate/drop suggestions, and use them to populate self.pending.
"""
candidates: Optional[set[WorkerState]]
candidates: set[WorkerState] | None
cmd: str
ws: Optional[WorkerState]
ws: WorkerState | None
ts: TaskState
nreplicas: int

Expand Down Expand Up @@ -194,9 +194,9 @@ def _run_policies(self) -> None:
def _find_recipient(
self,
ts: TaskState,
candidates: Optional[set[WorkerState]],
candidates: set[WorkerState] | None,
pending_repl: set[WorkerState],
) -> Optional[WorkerState]:
) -> WorkerState | None:
"""Choose a worker to acquire a new replica of an in-memory task among a set of
candidates. If candidates is None, default to all workers in the cluster.
Regardless, workers that either already hold a replica or are scheduled to
Expand All @@ -220,9 +220,9 @@ def _find_recipient(
def _find_dropper(
self,
ts: TaskState,
candidates: Optional[set[WorkerState]],
candidates: set[WorkerState] | None,
pending_drop: set[WorkerState],
) -> Optional[WorkerState]:
) -> WorkerState | None:
"""Choose a worker to drop its replica of an in-memory task among a set of
candidates. If candidates is None, default to all workers in the cluster.
Regardless, workers that either do not hold a replica or are already scheduled
Expand Down Expand Up @@ -258,8 +258,8 @@ def __repr__(self) -> str:
def run(
self,
) -> Generator[
tuple[str, TaskState, Optional[set[WorkerState]]],
Optional[WorkerState],
tuple[str, TaskState, set[WorkerState] | None],
WorkerState | None,
None,
]:
"""This method is invoked by the ActiveMemoryManager every few seconds, or
Expand Down
4 changes: 2 additions & 2 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4159,7 +4159,7 @@ class will be instantiated with any extra keyword arguments.
... pass
... def transition(self, key: str, start: str, finish: str, **kwargs):
... pass
... def release_key(self, key: str, state: str, cause: Optional[str], reason: None, report: bool):
... def release_key(self, key: str, state: str, cause: str | None, reason: None, report: bool):
... pass
>>> plugin = MyPlugin(1, 2, 3)
Expand Down Expand Up @@ -4228,7 +4228,7 @@ def unregister_worker_plugin(self, name, nanny=None):
... pass
... def transition(self, key: str, start: str, finish: str, **kwargs):
... pass
... def release_key(self, key: str, state: str, cause: Optional[str], reason: None, report: bool):
... def release_key(self, key: str, state: str, cause: str | None, reason: None, report: bool):
... pass
>>> plugin = MyPlugin(1, 2, 3)
Expand Down
9 changes: 6 additions & 3 deletions distributed/comm/core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
import inspect
import logging
Expand All @@ -6,6 +8,7 @@
import weakref
from abc import ABC, abstractmethod, abstractproperty
from contextlib import suppress
from typing import ClassVar

import dask
from dask.utils import parse_timedelta
Expand Down Expand Up @@ -40,7 +43,7 @@ class Comm(ABC):
depending on the underlying transport's characteristics.
"""

_instances = weakref.WeakSet()
_instances: ClassVar[weakref.WeakSet[Comm]] = weakref.WeakSet()

def __init__(self):
self._instances.add(self)
Expand All @@ -61,7 +64,7 @@ async def read(self, deserializers=None):
Parameters
----------
deserializers : Optional[Dict[str, Tuple[Callable, Callable, bool]]]
deserializers : dict[str, tuple[Callable, Callable, bool]] | None
An optional dict appropriate for distributed.protocol.deserialize.
See :ref:`serialization` for more.
"""
Expand All @@ -76,7 +79,7 @@ async def write(self, msg, serializers=None, on_error=None):
Parameters
----------
msg
on_error : Optional[str]
on_error : str | None
The behavior when serialization fails. See
``distributed.protocol.core.dumps`` for valid values.
"""
Expand Down
4 changes: 3 additions & 1 deletion distributed/comm/registry.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from abc import ABC, abstractmethod


Expand Down Expand Up @@ -54,7 +56,7 @@ def get_local_address_for(self, loc):


# The {scheme: Backend} mapping
backends = {}
backends: dict[str, Backend] = {}


def get_backend(scheme: str, require: bool = True) -> Backend:
Expand Down
2 changes: 1 addition & 1 deletion distributed/comm/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
try:
import ssl
except ImportError:
ssl = None
ssl = None # type: ignore

from tlz import sliding_window
from tornado import netutil
Expand Down
4 changes: 3 additions & 1 deletion distributed/comm/ws.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import asyncio
import logging
import struct
import warnings
import weakref
from collections.abc import Callable
from ssl import SSLError
from typing import Callable

from tornado import web
from tornado.httpclient import HTTPClientError, HTTPRequest
Expand Down
7 changes: 5 additions & 2 deletions distributed/compatibility.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

import logging
import platform
import sys

import tornado

logging_names = logging._levelToName.copy()
logging_names.update(logging._nameToLevel)
logging_names: dict[str | int, int | str] = {}
logging_names.update(logging._levelToName) # type: ignore
logging_names.update(logging._nameToLevel) # type: ignore

PYPY = platform.python_implementation().lower() == "pypy"
LINUX = sys.platform == "linux"
Expand Down
7 changes: 5 additions & 2 deletions distributed/core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
import inspect
import logging
Expand All @@ -11,6 +13,7 @@
from contextlib import suppress
from enum import Enum
from functools import partial
from typing import ClassVar

import tblib
from tlz import merge
Expand Down Expand Up @@ -695,7 +698,7 @@ class rpc:
>>> remote.close_comms() # doctest: +SKIP
"""

active = weakref.WeakSet()
active: ClassVar[weakref.WeakSet[rpc]] = weakref.WeakSet()
comms = ()
address = None

Expand Down Expand Up @@ -928,7 +931,7 @@ class ConnectionPool:
Whether or not to deserialize data by default or pass it through
"""

_instances = weakref.WeakSet()
_instances: ClassVar[weakref.WeakSet[ConnectionPool]] = weakref.WeakSet()

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion distributed/dashboard/components/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
if dask.config.get("distributed.dashboard.export-tool"):
from distributed.dashboard.export_tool import ExportTool
else:
ExportTool = None
ExportTool = None # type: ignore

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion distributed/dashboard/components/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
if dask.config.get("distributed.dashboard.export-tool"):
from distributed.dashboard.export_tool import ExportTool
else:
ExportTool = None
ExportTool = None # type: ignore


profile_interval = dask.config.get("distributed.worker.profile.interval")
Expand Down
3 changes: 2 additions & 1 deletion distributed/dashboard/tests/test_scheduler_bokeh.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@
from distributed.utils import format_dashboard_link
from distributed.utils_test import dec, div, gen_cluster, get_cert, inc, slowinc

scheduler.PROFILING = False
# Imported from distributed.dashboard.utils
scheduler.PROFILING = False # type: ignore


@gen_cluster(client=True, scheduler_kwargs={"dashboard": True})
Expand Down
2 changes: 1 addition & 1 deletion distributed/deploy/adaptive_core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import collections
import logging
import math
from typing import Iterable
from collections.abc import Iterable

import tlz as toolz
from tornado.ioloop import IOLoop, PeriodicCallback
Expand Down
11 changes: 0 additions & 11 deletions distributed/deploy/local.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import atexit
import logging
import math
import warnings
import weakref

import toolz

Expand Down Expand Up @@ -258,12 +256,3 @@ def _repr_html_(self, cluster_status=None):
cluster_status=cluster_status,
)
return super()._repr_html_(cluster_status=cluster_status)


clusters_to_close = weakref.WeakSet()


@atexit.register
def close_clusters():
for cluster in list(clusters_to_close):
cluster.close(timeout=10)
Loading

0 comments on commit 62ded0f

Please sign in to comment.