Enable mypy in CI 1/2 (#5328)
crusaderky authored Sep 30, 2021
1 parent a468036 commit 7a3ea4c
Showing 67 changed files with 433 additions and 327 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -24,4 +24,4 @@ repos:
     hooks:
       - id: pyupgrade
         args:
-          - "--py37-plus"
+          - --py37-plus
2 changes: 1 addition & 1 deletion distributed/__init__.py
@@ -1,7 +1,7 @@
 from . import config  # isort:skip; load distributed configuration first
 from . import widgets  # isort:skip; load distributed widgets second
 import dask
-from dask.config import config
+from dask.config import config  # type: ignore
 from dask.utils import import_required

 from ._version import get_versions
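The "# type: ignore" comment added above is mypy's per-line escape hatch: it silences whatever error mypy reports for that exact line, here an import that mypy cannot fully analyse, without changing runtime behaviour. A minimal sketch of the mechanism on an ordinary assignment (illustrative only, not taken from this commit; the bracketed error code is optional and narrows the suppression to one class of error):

from __future__ import annotations

retries: int = 3
# mypy would reject reusing the variable with an incompatible type; the inline
# ignore acknowledges the mismatch on this line only, and runtime is unaffected.
retries = "unlimited"  # type: ignore[assignment]
print(retries)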
15 changes: 7 additions & 8 deletions distributed/_concurrent_futures_thread.py
@@ -5,20 +5,17 @@

 """Implements ThreadPoolExecutor."""

+from __future__ import annotations
+
 __author__ = "Brian Quinlan (brian@sweetapp.com)"

 import atexit
 import itertools
-from concurrent.futures import _base
-
-try:
-    import queue
-except ImportError:
-    import Queue as queue
-
 import os
+import queue
 import threading
 import weakref
+from concurrent.futures import _base

 # Workers are created as daemon threads. This is done to allow the interpreter
 # to exit when there are still idle threads in a ThreadPoolExecutor's thread
@@ -34,7 +31,9 @@
 # workers to exit when their work queues are empty and then waits until the
 # threads finish.

-_threads_queues = weakref.WeakKeyDictionary()
+_threads_queues: weakref.WeakKeyDictionary[
+    threading.Thread, queue.Queue
+] = weakref.WeakKeyDictionary()
 _shutdown = False

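The annotation added to _threads_queues combines two things: from __future__ import annotations keeps the subscripted WeakKeyDictionary[...] form legal on Python 3.7 because the annotation is never evaluated at runtime, and the explicit key/value types let mypy check what goes into the mapping, which it cannot infer from a bare constructor call. A standalone sketch of the same pattern, not part of this commit:

from __future__ import annotations  # annotations become lazy strings on 3.7+

import queue
import threading
import weakref

# Without the annotation mypy has nothing to infer key/value types from;
# with it, wrong insertions are flagged at check time.
_threads_queues: weakref.WeakKeyDictionary[
    threading.Thread, queue.Queue
] = weakref.WeakKeyDictionary()

worker = threading.Thread()
_threads_queues[worker] = queue.Queue()        # OK
# _threads_queues["worker-1"] = queue.Queue()  # mypy: invalid key type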
10 changes: 2 additions & 8 deletions distributed/_ipython_utils.py
@@ -6,13 +6,7 @@

 import atexit
 import os
-
-try:
-    import queue
-except ImportError:
-    # Python 2
-    import Queue as queue
-
+import queue
 import sys
 from subprocess import Popen
 from threading import Event, Thread
@@ -135,7 +129,7 @@ def remote_magic(line, cell=None):


 # cache clients for re-use in remote magic
-remote_magic._clients = {}
+remote_magic._clients = {}  # type: ignore


 def register_remote_magic(magic_name="remote"):
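remote_magic._clients stores a cache directly on the function object. That works at runtime, but mypy types remote_magic as a plain Callable and flags the unknown attribute, hence the ignore above. A reduced sketch of the same situation (the function body is a stand-in, not the real magic):

def remote_magic(line, cell=None):
    """Stand-in for the real IPython magic; the body is elided here."""


# Runtime happily attaches arbitrary attributes to a function object, but mypy
# sees a Callable with no "_clients" attribute and reports attr-defined, so the
# assignment carries a per-line ignore.
remote_magic._clients = {}  # type: ignore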
2 changes: 1 addition & 1 deletion distributed/_version.py
@@ -50,7 +50,7 @@ class NotThisMethod(Exception):
     """Exception raised if a method is not valid for the current scenario."""


-LONG_VERSION_PY = {}
+LONG_VERSION_PY: dict = {}
 HANDLERS = {}

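Annotating LONG_VERSION_PY with a bare dict gives the variable a declared type where an empty literal leaves mypy nothing to infer element types from, while the parameter-free form keeps the permissive behaviour of the versioneer-generated code. A sketch of the idea, with a hypothetical narrower variant for contrast (TEMPLATES is made up, not from this file):

from __future__ import annotations

LONG_VERSION_PY: dict = {}        # any keys and values are accepted
TEMPLATES: dict[str, str] = {}    # hypothetical narrower variant for contrast

TEMPLATES["git"] = "%(describe)s"   # OK
# TEMPLATES["git"] = 42             # mypy: incompatible value type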
20 changes: 11 additions & 9 deletions distributed/active_memory_manager.py
@@ -13,7 +13,7 @@
 from .utils import import_term

 if TYPE_CHECKING:
-    from .scheduler import SchedulerState, TaskState, WorkerState
+    from .scheduler import Scheduler, TaskState, WorkerState


 class ActiveMemoryManagerExtension:
@@ -31,7 +31,7 @@ class ActiveMemoryManagerExtension:
     ``distributed.scheduler.active-memory-manager``.
     """

-    scheduler: SchedulerState
+    scheduler: Scheduler
     policies: set[ActiveMemoryManagerPolicy]
     interval: float

@@ -43,7 +43,7 @@

     def __init__(
         self,
-        scheduler: SchedulerState,
+        scheduler: Scheduler,
         # The following parameters are exposed so that one may create, run, and throw
         # away on the fly a specialized manager, separate from the main one.
         policies: set[ActiveMemoryManagerPolicy] | None = None,
@@ -126,12 +126,14 @@ def run_once(self, comm=None) -> None:
         # populate self.pending
         self._run_policies()

-        drop_by_worker = defaultdict(set)
-        repl_by_worker = defaultdict(dict)
+        drop_by_worker: defaultdict[str, set[str]] = defaultdict(set)
+        repl_by_worker: defaultdict[str, dict[str, list[str]]] = defaultdict(dict)
+
         for ts, (pending_repl, pending_drop) in self.pending.items():
             if not ts.who_has:
                 continue
             who_has = [ws_snd.address for ws_snd in ts.who_has - pending_drop]
+
             assert who_has  # Never drop the last replica
             for ws_rec in pending_repl:
                 assert ws_rec not in ts.who_has
@@ -143,8 +145,8 @@
         # Fire-and-forget enact recommendations from policies
         # This is temporary code, waiting for
         # https://github.com/dask/distributed/pull/5046
-        for addr, who_has in repl_by_worker.items():
-            asyncio.create_task(self.scheduler.gather_on_worker(addr, who_has))
+        for addr, who_has_map in repl_by_worker.items():
+            asyncio.create_task(self.scheduler.gather_on_worker(addr, who_has_map))
         for addr, keys in drop_by_worker.items():
             asyncio.create_task(self.scheduler.delete_worker_data(addr, keys))
         # End temporary code
@@ -215,7 +217,7 @@ def _find_recipient(
         candidates -= pending_repl
         if not candidates:
             return None
-        return min(candidates, key=self.workers_memory.get)
+        return min(candidates, key=self.workers_memory.__getitem__)

     def _find_dropper(
         self,
@@ -244,7 +246,7 @@ def _find_dropper(
         candidates -= {waiter_ts.processing_on for waiter_ts in ts.waiters}
         if not candidates:
             return None
-        return max(candidates, key=self.workers_memory.get)
+        return max(candidates, key=self.workers_memory.__getitem__)


 class ActiveMemoryManagerPolicy:
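Swapping workers_memory.get for workers_memory.__getitem__ as the min/max key is a typing-driven change: Mapping.get returns an optional value, which mypy flags as an invalid sort key because None does not support ordering, while __getitem__ returns the plain value type and every candidate is known to be present. A self-contained sketch with made-up worker addresses and memory figures:

from __future__ import annotations

workers_memory: dict[str, int] = {"tcp://10.0.0.1:40331": 120, "tcp://10.0.0.2:40331": 80}
candidates = {"tcp://10.0.0.1:40331", "tcp://10.0.0.2:40331"}

# key=workers_memory.get would be typed "int | None", which mypy rejects for min();
# __getitem__ returns a plain int and raises KeyError on a genuinely missing key.
least_loaded = min(candidates, key=workers_memory.__getitem__)
print(least_loaded)  # tcp://10.0.0.2:40331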
19 changes: 9 additions & 10 deletions distributed/cli/tests/test_dask_scheduler.py
@@ -209,23 +209,22 @@ def test_scheduler_port_zero(loop):
     with tmpfile() as fn:
         with popen(
             ["dask-scheduler", "--no-dashboard", "--scheduler-file", fn, "--port", "0"]
-        ) as sched:
+        ):
             with Client(scheduler_file=fn, loop=loop) as c:
                 assert c.scheduler.port
                 assert c.scheduler.port != 8786


 def test_dashboard_port_zero(loop):
     pytest.importorskip("bokeh")
-    with tmpfile() as fn:
-        with popen(["dask-scheduler", "--dashboard-address", ":0"]) as proc:
-            count = 0
-            while count < 1:
-                line = proc.stderr.readline()
-                if b"dashboard" in line.lower():
-                    sleep(0.01)
-                    count += 1
-                    assert b":0" not in line
+    with popen(["dask-scheduler", "--dashboard-address", ":0"]) as proc:
+        count = 0
+        while count < 1:
+            line = proc.stderr.readline()
+            if b"dashboard" in line.lower():
+                sleep(0.01)
+                count += 1
+                assert b":0" not in line


 PRELOAD_TEXT = """
26 changes: 15 additions & 11 deletions distributed/client.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import atexit
 import copy
@@ -22,6 +24,7 @@
 from functools import partial
 from numbers import Number
 from queue import Queue as pyQueue
+from typing import ClassVar

 from tlz import first, groupby, keymap, merge, partition_all, valmap

@@ -49,7 +52,7 @@
 from tornado import gen
 from tornado.ioloop import IOLoop, PeriodicCallback

-from . import versions as version_module
+from . import versions as version_module  # type: ignore
 from .batched import BatchedSend
 from .cfexecutor import ClientExecutor
 from .core import (
@@ -95,7 +98,9 @@

 logger = logging.getLogger(__name__)

-_global_clients = weakref.WeakValueDictionary()
+_global_clients: weakref.WeakValueDictionary[
+    int, Client
+] = weakref.WeakValueDictionary()
 _global_client_index = [0]

 _current_client = ContextVar("_current_client", default=None)
@@ -105,7 +110,7 @@
 NO_DEFAULT_PLACEHOLDER = "_no_default_"


-def _get_global_client():
+def _get_global_client() -> Client | None:
     L = sorted(list(_global_clients), reverse=True)
     for k in L:
         c = _global_clients[k]
@@ -116,13 +121,13 @@ def _get_global_client():
     return None


-def _set_global_client(c):
+def _set_global_client(c: Client | None) -> None:
     if c is not None:
         _global_clients[_global_client_index[0]] = c
         _global_client_index[0] += 1


-def _del_global_client(c):
+def _del_global_client(c: Client) -> None:
     for k in list(_global_clients):
         try:
             if _global_clients[k] is c:
@@ -590,7 +595,7 @@ class Client:
     distributed.LocalCluster:
     """

-    _instances = weakref.WeakSet()
+    _instances: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet()

     _default_event_handlers = {"print": _handle_print, "warn": _handle_warn}

@@ -1377,8 +1382,6 @@ async def _close(self, fast=False):

         self.status = "closed"

-    _shutdown = _close
-
     def close(self, timeout=no_default):
         """Close this client
@@ -2529,12 +2532,13 @@ def _get_computation_code() -> str:
     )
     if not isinstance(ignore_modules, list):
         raise TypeError(
-            f"Ignored modules must be a list. Instead got ({type(ignore_modules)}, {ignore_modules})"
+            "Ignored modules must be a list. Instead got "
+            f"({type(ignore_modules)}, {ignore_modules})"
         )

+    pattern: re.Pattern | None
     if ignore_modules:
-        pattern = "|".join([f"(?:{mod})" for mod in ignore_modules])
-        pattern = re.compile(pattern)
+        pattern = re.compile("|".join([f"(?:{mod})" for mod in ignore_modules]))
     else:
         pattern = None

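The new _instances annotation leans on two pieces visible elsewhere in this diff: from __future__ import annotations makes the self-referential Client reference legal inside the class body, and ClassVar marks the WeakSet as shared class state rather than a per-instance attribute. A reduced sketch of the pattern, not the real Client class:

from __future__ import annotations

import weakref
from typing import ClassVar


class Client:
    # One registry shared by all instances; ClassVar tells mypy it must not be
    # shadowed on instances, and the deferred annotations allow naming Client
    # before the class statement has finished executing.
    _instances: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet()

    def __init__(self) -> None:
        Client._instances.add(self)


c = Client()
print(len(Client._instances))  # 1 while c is still alive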
28 changes: 17 additions & 11 deletions distributed/comm/addressing.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import itertools

 import dask
@@ -6,7 +8,7 @@
 from . import registry


-def parse_address(addr, strict=False):
+def parse_address(addr: str, strict: bool = False) -> tuple[str, str]:
     """
     Split address into its scheme and scheme-dependent location string.
Expand All @@ -30,7 +32,7 @@ def parse_address(addr, strict=False):
return scheme, loc


def unparse_address(scheme, loc):
def unparse_address(scheme: str, loc: str) -> str:
"""
Undo parse_address().
@@ -40,7 +42,7 @@ def unparse_address(scheme, loc):
     return f"{scheme}://{loc}"


-def normalize_address(addr):
+def normalize_address(addr: str) -> str:
     """
     Canonicalize address, adding a default scheme if necessary.
@@ -52,7 +54,9 @@ def normalize_address(addr):
     return unparse_address(*parse_address(addr))


-def parse_host_port(address, default_port=None):
+def parse_host_port(
+    address: str | tuple[str, int], default_port: str | int | None = None
+) -> tuple[str, int]:
     """
     Parse an endpoint address given in the form "host:port".
     """
@@ -95,19 +99,19 @@ def _default():
     return host, int(port)


-def unparse_host_port(host, port=None):
+def unparse_host_port(host: str, port: int | None = None) -> str:
     """
     Undo parse_host_port().
     """
     if ":" in host and not host.startswith("["):
-        host = "[%s]" % host
+        host = f"[{host}]"
     if port is not None:
         return f"{host}:{port}"
     else:
         return host


-def get_address_host_port(addr, strict=False):
+def get_address_host_port(addr: str, strict: bool = False) -> tuple[str, int]:
     """
     Get a (host, port) tuple out of the given address.
     For definition of strict check parse_address
@@ -129,7 +133,7 @@
     )


-def get_address_host(addr):
+def get_address_host(addr: str) -> str:
     """
     Return a hostname / IP address identifying the machine this address
     is located on.
@@ -145,7 +149,7 @@
     return backend.get_address_host(loc)


-def get_local_address_for(addr):
+def get_local_address_for(addr: str) -> str:
     """
     Get a local listening address suitable for reaching *addr*.
@@ -162,7 +166,7 @@ def get_local_address_for(addr):
     return unparse_address(scheme, backend.get_local_address_for(loc))


-def resolve_address(addr):
+def resolve_address(addr: str) -> str:
     """
     Apply scheme-specific address resolution to *addr*, replacing
     all symbolic references with concrete location specifiers.
@@ -177,7 +181,9 @@ def resolve_address(addr):
     return unparse_address(scheme, backend.resolve_address(loc))


-def uri_from_host_port(host_arg, port_arg, default_port):
+def uri_from_host_port(
+    host_arg: str | None, port_arg: str | None, default_port: int
+) -> str:
     """
     Process the *host* and *port* CLI options.
     Return a URI.
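All of the new signatures in this file use built-in generics (tuple[str, int]) and PEP 604 unions (str | None) even though distributed still supported Python 3.7 at the time; that only works because from __future__ import annotations stops the annotations from being evaluated at runtime. A simplified, hypothetical parser in the same style (not the real implementation, which also handles IPv6 brackets and scheme prefixes):

from __future__ import annotations


def parse_host_port(
    address: str | tuple[str, int], default_port: int | None = None
) -> tuple[str, int]:
    """Toy version of the helper above: split "host:port" into its parts."""
    if isinstance(address, tuple):
        return address
    host, _, port = address.rpartition(":")
    if not host:  # no ":" in the input, so everything is the host
        host, port = address, ""
    return host, int(port) if port else int(default_port or 0)


print(parse_host_port("127.0.0.1:8786"))   # ('127.0.0.1', 8786)
print(parse_host_port("localhost", 8786))  # ('localhost', 8786)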