Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable mypy in CI 1/2 #5328

Merged
merged 24 commits into from
Sep 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ repos:
hooks:
- id: pyupgrade
args:
- "--py37-plus"
- --py37-plus
2 changes: 1 addition & 1 deletion distributed/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from . import config # isort:skip; load distributed configuration first
from . import widgets # isort:skip; load distributed widgets second
import dask
from dask.config import config
from dask.config import config # type: ignore
from dask.utils import import_required

from ._version import get_versions
Expand Down
15 changes: 7 additions & 8 deletions distributed/_concurrent_futures_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,17 @@

"""Implements ThreadPoolExecutor."""

from __future__ import annotations

__author__ = "Brian Quinlan (brian@sweetapp.com)"

import atexit
import itertools
from concurrent.futures import _base

try:
import queue
except ImportError:
import Queue as queue

import os
import queue
import threading
import weakref
from concurrent.futures import _base

# Workers are created as daemon threads. This is done to allow the interpreter
# to exit when there are still idle threads in a ThreadPoolExecutor's thread
Expand All @@ -34,7 +31,9 @@
# workers to exit when their work queues are empty and then waits until the
# threads finish.

_threads_queues = weakref.WeakKeyDictionary()
_threads_queues: weakref.WeakKeyDictionary[
threading.Thread, queue.Queue
] = weakref.WeakKeyDictionary()
_shutdown = False


Expand Down
10 changes: 2 additions & 8 deletions distributed/_ipython_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,7 @@

import atexit
import os

try:
import queue
except ImportError:
# Python 2
import Queue as queue

import queue
import sys
from subprocess import Popen
from threading import Event, Thread
Expand Down Expand Up @@ -135,7 +129,7 @@ def remote_magic(line, cell=None):


# cache clients for re-use in remote magic
remote_magic._clients = {}
remote_magic._clients = {} # type: ignore
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
remote_magic._clients = {} # type: ignore
remote_magic._clients: Dict[str, BlockingKernelClient] = {}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, fixed. We should always use lowercase dict[...], list[...], etc. and put from __future__ import annotations at the top of the module.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correction: can't.

distributed/_ipython_utils.py:133: error: Type cannot be declared in assignment to non-self attribute
distributed/_ipython_utils.py:133: error: "Callable[[Any, Any], Any]" has no attribute "_clients"



def register_remote_magic(magic_name="remote"):
Expand Down
2 changes: 1 addition & 1 deletion distributed/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class NotThisMethod(Exception):
"""Exception raised if a method is not valid for the current scenario."""


LONG_VERSION_PY = {}
LONG_VERSION_PY: dict = {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this file is missing a from __future__ import __annotations to make this work.

Copy link
Collaborator Author

@crusaderky crusaderky Sep 27, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's only necessary if you use bracket-stile dict[...] or if the annotations create circular dependencies

HANDLERS = {}


Expand Down
20 changes: 11 additions & 9 deletions distributed/active_memory_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .utils import import_term

if TYPE_CHECKING:
from .scheduler import SchedulerState, TaskState, WorkerState
from .scheduler import Scheduler, TaskState, WorkerState


class ActiveMemoryManagerExtension:
Expand All @@ -31,7 +31,7 @@ class ActiveMemoryManagerExtension:
``distributed.scheduler.active-memory-manager``.
"""

scheduler: SchedulerState
scheduler: Scheduler
policies: set[ActiveMemoryManagerPolicy]
interval: float

Expand All @@ -43,7 +43,7 @@ class ActiveMemoryManagerExtension:

def __init__(
self,
scheduler: SchedulerState,
scheduler: Scheduler,
# The following parameters are exposed so that one may create, run, and throw
# away on the fly a specialized manager, separate from the main one.
policies: set[ActiveMemoryManagerPolicy] | None = None,
Expand Down Expand Up @@ -126,12 +126,14 @@ def run_once(self, comm=None) -> None:
# populate self.pending
self._run_policies()

drop_by_worker = defaultdict(set)
repl_by_worker = defaultdict(dict)
drop_by_worker: defaultdict[str, set[str]] = defaultdict(set)
repl_by_worker: defaultdict[str, dict[str, list[str]]] = defaultdict(dict)

for ts, (pending_repl, pending_drop) in self.pending.items():
if not ts.who_has:
continue
who_has = [ws_snd.address for ws_snd in ts.who_has - pending_drop]

assert who_has # Never drop the last replica
for ws_rec in pending_repl:
assert ws_rec not in ts.who_has
Expand All @@ -143,8 +145,8 @@ def run_once(self, comm=None) -> None:
# Fire-and-forget enact recommendations from policies
# This is temporary code, waiting for
# https://github.com/dask/distributed/pull/5046
for addr, who_has in repl_by_worker.items():
asyncio.create_task(self.scheduler.gather_on_worker(addr, who_has))
for addr, who_has_map in repl_by_worker.items():
asyncio.create_task(self.scheduler.gather_on_worker(addr, who_has_map))
for addr, keys in drop_by_worker.items():
asyncio.create_task(self.scheduler.delete_worker_data(addr, keys))
# End temporary code
Expand Down Expand Up @@ -215,7 +217,7 @@ def _find_recipient(
candidates -= pending_repl
if not candidates:
return None
return min(candidates, key=self.workers_memory.get)
return min(candidates, key=self.workers_memory.__getitem__)

def _find_dropper(
self,
Expand Down Expand Up @@ -244,7 +246,7 @@ def _find_dropper(
candidates -= {waiter_ts.processing_on for waiter_ts in ts.waiters}
if not candidates:
return None
return max(candidates, key=self.workers_memory.get)
return max(candidates, key=self.workers_memory.__getitem__)


class ActiveMemoryManagerPolicy:
Expand Down
19 changes: 9 additions & 10 deletions distributed/cli/tests/test_dask_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,23 +209,22 @@ def test_scheduler_port_zero(loop):
with tmpfile() as fn:
with popen(
["dask-scheduler", "--no-dashboard", "--scheduler-file", fn, "--port", "0"]
) as sched:
):
with Client(scheduler_file=fn, loop=loop) as c:
assert c.scheduler.port
assert c.scheduler.port != 8786


def test_dashboard_port_zero(loop):
pytest.importorskip("bokeh")
with tmpfile() as fn:
with popen(["dask-scheduler", "--dashboard-address", ":0"]) as proc:
count = 0
while count < 1:
line = proc.stderr.readline()
if b"dashboard" in line.lower():
sleep(0.01)
count += 1
assert b":0" not in line
with popen(["dask-scheduler", "--dashboard-address", ":0"]) as proc:
count = 0
while count < 1:
line = proc.stderr.readline()
if b"dashboard" in line.lower():
sleep(0.01)
count += 1
assert b":0" not in line


PRELOAD_TEXT = """
Expand Down
26 changes: 15 additions & 11 deletions distributed/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
import atexit
import copy
Expand All @@ -22,6 +24,7 @@
from functools import partial
from numbers import Number
from queue import Queue as pyQueue
from typing import ClassVar

from tlz import first, groupby, keymap, merge, partition_all, valmap

Expand Down Expand Up @@ -49,7 +52,7 @@
from tornado import gen
from tornado.ioloop import IOLoop, PeriodicCallback

from . import versions as version_module
from . import versions as version_module # type: ignore
from .batched import BatchedSend
from .cfexecutor import ClientExecutor
from .core import (
Expand Down Expand Up @@ -95,7 +98,9 @@

logger = logging.getLogger(__name__)

_global_clients = weakref.WeakValueDictionary()
_global_clients: weakref.WeakValueDictionary[
int, Client
] = weakref.WeakValueDictionary()
_global_client_index = [0]

_current_client = ContextVar("_current_client", default=None)
Expand All @@ -105,7 +110,7 @@
NO_DEFAULT_PLACEHOLDER = "_no_default_"


def _get_global_client():
def _get_global_client() -> Client | None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to myself and others: 3.10 syntax, but backwards compatible with from __future__ import annotations, AFAIK.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

correct. 3.10 syntax works on all python versions as long as your version of mypy is recent.
Notable exception: cast, since it is executed at runtime, needs Python 3.10 style annotations wrapped in a string.

L = sorted(list(_global_clients), reverse=True)
for k in L:
c = _global_clients[k]
Expand All @@ -116,13 +121,13 @@ def _get_global_client():
return None


def _set_global_client(c):
def _set_global_client(c: Client | None) -> None:
if c is not None:
_global_clients[_global_client_index[0]] = c
_global_client_index[0] += 1


def _del_global_client(c):
def _del_global_client(c: Client) -> None:
for k in list(_global_clients):
try:
if _global_clients[k] is c:
Expand Down Expand Up @@ -590,7 +595,7 @@ class Client:
distributed.LocalCluster:
"""

_instances = weakref.WeakSet()
_instances: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet()

_default_event_handlers = {"print": _handle_print, "warn": _handle_warn}

Expand Down Expand Up @@ -1377,8 +1382,6 @@ async def _close(self, fast=False):

self.status = "closed"

_shutdown = _close

Comment on lines -1380 to -1381
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just noting to future readers that there's a _shutdown method defined a few lines below

def close(self, timeout=no_default):
"""Close this client

Expand Down Expand Up @@ -2529,12 +2532,13 @@ def _get_computation_code() -> str:
)
if not isinstance(ignore_modules, list):
raise TypeError(
f"Ignored modules must be a list. Instead got ({type(ignore_modules)}, {ignore_modules})"
"Ignored modules must be a list. Instead got "
f"({type(ignore_modules)}, {ignore_modules})"
)

pattern: re.Pattern | None
if ignore_modules:
pattern = "|".join([f"(?:{mod})" for mod in ignore_modules])
pattern = re.compile(pattern)
pattern = re.compile("|".join([f"(?:{mod})" for mod in ignore_modules]))
else:
pattern = None

Expand Down
28 changes: 17 additions & 11 deletions distributed/comm/addressing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import itertools

import dask
Expand All @@ -6,7 +8,7 @@
from . import registry


def parse_address(addr, strict=False):
def parse_address(addr: str, strict: bool = False) -> tuple[str, str]:
"""
Split address into its scheme and scheme-dependent location string.

Expand All @@ -30,7 +32,7 @@ def parse_address(addr, strict=False):
return scheme, loc


def unparse_address(scheme, loc):
def unparse_address(scheme: str, loc: str) -> str:
"""
Undo parse_address().

Expand All @@ -40,7 +42,7 @@ def unparse_address(scheme, loc):
return f"{scheme}://{loc}"


def normalize_address(addr):
def normalize_address(addr: str) -> str:
"""
Canonicalize address, adding a default scheme if necessary.

Expand All @@ -52,7 +54,9 @@ def normalize_address(addr):
return unparse_address(*parse_address(addr))


def parse_host_port(address, default_port=None):
def parse_host_port(
address: str | tuple[str, int], default_port: str | int | None = None
) -> tuple[str, int]:
"""
Parse an endpoint address given in the form "host:port".
"""
Expand Down Expand Up @@ -95,19 +99,19 @@ def _default():
return host, int(port)


def unparse_host_port(host, port=None):
def unparse_host_port(host: str, port: int | None = None) -> str:
"""
Undo parse_host_port().
"""
if ":" in host and not host.startswith("["):
host = "[%s]" % host
host = f"[{host}]"
if port is not None:
return f"{host}:{port}"
else:
return host


def get_address_host_port(addr, strict=False):
def get_address_host_port(addr: str, strict: bool = False) -> tuple[str, int]:
"""
Get a (host, port) tuple out of the given address.
For definition of strict check parse_address
Expand All @@ -129,7 +133,7 @@ def get_address_host_port(addr, strict=False):
)


def get_address_host(addr):
def get_address_host(addr: str) -> str:
"""
Return a hostname / IP address identifying the machine this address
is located on.
Expand All @@ -145,7 +149,7 @@ def get_address_host(addr):
return backend.get_address_host(loc)


def get_local_address_for(addr):
def get_local_address_for(addr: str) -> str:
"""
Get a local listening address suitable for reaching *addr*.

Expand All @@ -162,7 +166,7 @@ def get_local_address_for(addr):
return unparse_address(scheme, backend.get_local_address_for(loc))


def resolve_address(addr):
def resolve_address(addr: str) -> str:
"""
Apply scheme-specific address resolution to *addr*, replacing
all symbolic references with concrete location specifiers.
Expand All @@ -177,7 +181,9 @@ def resolve_address(addr):
return unparse_address(scheme, backend.resolve_address(loc))


def uri_from_host_port(host_arg, port_arg, default_port):
def uri_from_host_port(
host_arg: str | None, port_arg: str | None, default_port: int
) -> str:
"""
Process the *host* and *port* CLI options.
Return a URI.
Expand Down
Loading