Skip to content

Commit

Permalink
annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Sep 23, 2021
1 parent 2b2de7e commit 3de84a1
Show file tree
Hide file tree
Showing 26 changed files with 289 additions and 197 deletions.
27 changes: 16 additions & 11 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,23 @@ repos:
hooks:
- id: pyupgrade
args:
- "--py37-plus"
- --py37-plus
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.910
hooks:
- id: mypy
additional_dependencies: [
# Type stubs
types-docutils,
types-requests,
types-paramiko,
types-pkg_resources,
types-PyYAML,
types-setuptools,
types-psutil,
]
additional_dependencies:
# Type stubs
- types-docutils
- types-requests
- types-paramiko
- types-pkg_resources
- types-PyYAML
- types-setuptools
- types-psutil
# Libraries exclusively imported under `if TYPE_CHECKING:`
- typing_extensions # To be reviewed after dropping Python 3.7
# Typed libraries
- numpy
- dask
- tornado
26 changes: 15 additions & 11 deletions distributed/active_memory_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import asyncio
from collections import defaultdict
from collections.abc import Generator
from typing import TYPE_CHECKING
from collections.abc import Callable, Generator
from typing import TYPE_CHECKING, cast

from tornado.ioloop import PeriodicCallback

Expand All @@ -13,7 +13,7 @@
from .utils import import_term

if TYPE_CHECKING:
from .scheduler import SchedulerState, TaskState, WorkerState
from .scheduler import Scheduler, TaskState, WorkerState


class ActiveMemoryManagerExtension:
Expand All @@ -31,7 +31,7 @@ class ActiveMemoryManagerExtension:
``distributed.scheduler.active-memory-manager``.
"""

scheduler: SchedulerState
scheduler: Scheduler
policies: set[ActiveMemoryManagerPolicy]
interval: float

Expand All @@ -43,7 +43,7 @@ class ActiveMemoryManagerExtension:

def __init__(
self,
scheduler: SchedulerState,
scheduler: Scheduler,
# The following parameters are exposed so that one may create, run, and throw
# away on the fly a specialized manager, separate from the main one.
policies: set[ActiveMemoryManagerPolicy] | None = None,
Expand Down Expand Up @@ -126,12 +126,14 @@ def run_once(self, comm=None) -> None:
# populate self.pending
self._run_policies()

drop_by_worker = defaultdict(set)
repl_by_worker = defaultdict(dict)
drop_by_worker: defaultdict[str, set[str]] = defaultdict(set)
repl_by_worker: defaultdict[str, dict[str, list[str]]] = defaultdict(dict)

for ts, (pending_repl, pending_drop) in self.pending.items():
if not ts.who_has:
continue
who_has = [ws_snd.address for ws_snd in ts.who_has - pending_drop]

assert who_has # Never drop the last replica
for ws_rec in pending_repl:
assert ws_rec not in ts.who_has
Expand All @@ -143,8 +145,8 @@ def run_once(self, comm=None) -> None:
# Fire-and-forget enact recommendations from policies
# This is temporary code, waiting for
# https://github.com/dask/distributed/pull/5046
for addr, who_has in repl_by_worker.items():
asyncio.create_task(self.scheduler.gather_on_worker(addr, who_has))
for addr, who_has_map in repl_by_worker.items():
asyncio.create_task(self.scheduler.gather_on_worker(addr, who_has_map))
for addr, keys in drop_by_worker.items():
asyncio.create_task(self.scheduler.delete_worker_data(addr, keys))
# End temporary code
Expand Down Expand Up @@ -215,7 +217,8 @@ def _find_recipient(
candidates -= pending_repl
if not candidates:
return None
return min(candidates, key=self.workers_memory.get)
key = cast(Callable[[WorkerState], int], self.workers_memory.get)
return min(candidates, key=key)

def _find_dropper(
self,
Expand Down Expand Up @@ -244,7 +247,8 @@ def _find_dropper(
candidates -= {waiter_ts.processing_on for waiter_ts in ts.waiters}
if not candidates:
return None
return max(candidates, key=self.workers_memory.get)
key = cast(Callable[[WorkerState], int], self.workers_memory.get)
return max(candidates, key=key)


class ActiveMemoryManagerPolicy:
Expand Down
26 changes: 15 additions & 11 deletions distributed/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
import atexit
import copy
Expand All @@ -22,6 +24,7 @@
from functools import partial
from numbers import Number
from queue import Queue as pyQueue
from typing import ClassVar

from tlz import first, groupby, keymap, merge, partition_all, valmap

Expand Down Expand Up @@ -49,7 +52,7 @@
from tornado import gen
from tornado.ioloop import IOLoop, PeriodicCallback

from . import versions as version_module
from . import versions as version_module # type: ignore
from .batched import BatchedSend
from .cfexecutor import ClientExecutor
from .core import (
Expand Down Expand Up @@ -95,7 +98,9 @@

logger = logging.getLogger(__name__)

_global_clients = weakref.WeakValueDictionary()
_global_clients: weakref.WeakValueDictionary[
int, Client
] = weakref.WeakValueDictionary()
_global_client_index = [0]

_current_client = ContextVar("_current_client", default=None)
Expand All @@ -105,7 +110,7 @@
NO_DEFAULT_PLACEHOLDER = "_no_default_"


def _get_global_client():
def _get_global_client() -> Client | None:
L = sorted(list(_global_clients), reverse=True)
for k in L:
c = _global_clients[k]
Expand All @@ -116,13 +121,13 @@ def _get_global_client():
return None


def _set_global_client(c):
def _set_global_client(c: Client | None) -> None:
if c is not None:
_global_clients[_global_client_index[0]] = c
_global_client_index[0] += 1


def _del_global_client(c):
def _del_global_client(c: Client) -> None:
for k in list(_global_clients):
try:
if _global_clients[k] is c:
Expand Down Expand Up @@ -590,7 +595,7 @@ class Client:
distributed.LocalCluster:
"""

_instances = weakref.WeakSet()
_instances: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet()

_default_event_handlers = {"print": _handle_print, "warn": _handle_warn}

Expand Down Expand Up @@ -1377,8 +1382,6 @@ async def _close(self, fast=False):

self.status = "closed"

_shutdown = _close

def close(self, timeout=no_default):
"""Close this client
Expand Down Expand Up @@ -2529,12 +2532,13 @@ def _get_computation_code() -> str:
)
if not isinstance(ignore_modules, list):
raise TypeError(
f"Ignored modules must be a list. Instead got ({type(ignore_modules)}, {ignore_modules})"
"Ignored modules must be a list. Instead got "
f"({type(ignore_modules)}, {ignore_modules})"
)

pattern: re.Pattern | None
if ignore_modules:
pattern = "|".join([f"(?:{mod})" for mod in ignore_modules])
pattern = re.compile(pattern)
pattern = re.compile("|".join([f"(?:{mod})" for mod in ignore_modules]))
else:
pattern = None

Expand Down
6 changes: 4 additions & 2 deletions distributed/comm/addressing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import itertools

import dask
Expand All @@ -6,7 +8,7 @@
from . import registry


def parse_address(addr, strict=False):
def parse_address(addr: str, strict: bool = False) -> tuple[str, str]:
"""
Split address into its scheme and scheme-dependent location string.
Expand Down Expand Up @@ -145,7 +147,7 @@ def get_address_host(addr):
return backend.get_address_host(loc)


def get_local_address_for(addr):
def get_local_address_for(addr: str) -> str:
"""
Get a local listening address suitable for reaching *addr*.
Expand Down
34 changes: 16 additions & 18 deletions distributed/comm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import random
import sys
import weakref
from abc import ABC, abstractmethod, abstractproperty
from abc import ABC, abstractmethod
from contextlib import suppress
from typing import ClassVar

Expand Down Expand Up @@ -102,21 +102,17 @@ def abort(self):

@abstractmethod
def closed(self):
"""
Return whether the stream is closed.
"""
"""Return whether the stream is closed."""

@abstractproperty
def local_address(self):
"""
The local address. For logging and debugging purposes only.
"""
@property
@abstractmethod
def local_address(self) -> str:
"""The local address. For logging and debugging purposes only."""

@abstractproperty
def peer_address(self):
"""
The peer's address. For logging and debugging purposes only.
"""
@property
@abstractmethod
def peer_address(self) -> str:
"""The peer's address. For logging and debugging purposes only."""

@property
def extra_info(self):
Expand Down Expand Up @@ -181,13 +177,15 @@ def stop(self):
communications, but prevents accepting new ones.
"""

@abstractproperty
@property
@abstractmethod
def listen_address(self):
"""
The listening address as a URI string.
"""

@abstractproperty
@property
@abstractmethod
def contact_address(self):
"""
An address this listener can be contacted on. This can be
Expand Down Expand Up @@ -230,9 +228,9 @@ async def on_connection(self, comm: Comm, handshake_overrides=None):
raise CommClosedError(f"Comm {comm!r} closed.") from e

comm.remote_info = handshake
comm.remote_info["address"] = comm._peer_addr
comm.remote_info["address"] = comm.peer_address
comm.local_info = local_info
comm.local_info["address"] = comm._local_addr
comm.local_info["address"] = comm.local_address

comm.handshake_options = comm.handshake_configuration(
comm.local_info, comm.remote_info
Expand Down
12 changes: 9 additions & 3 deletions distributed/comm/inproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,13 @@ class InProc(Comm):
_initialized = False

def __init__(
self, local_addr, peer_addr, read_q, write_q, write_loop, deserialize=True
self,
local_addr: str,
peer_addr: str,
read_q,
write_q,
write_loop,
deserialize=True,
):
super().__init__()
self._local_addr = local_addr
Expand All @@ -176,11 +182,11 @@ def finalize(write_q=self._write_q, write_loop=self._write_loop, r=repr(self)):
return finalize

@property
def local_address(self):
def local_address(self) -> str:
return self._local_addr

@property
def peer_address(self):
def peer_address(self) -> str:
return self._peer_addr

async def read(self, deserializers="ignored"):
Expand Down
21 changes: 14 additions & 7 deletions distributed/comm/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,13 @@ class TCP(Comm):

max_shard_size = dask.utils.parse_bytes(dask.config.get("distributed.comm.shard"))

def __init__(self, stream, local_addr, peer_addr, deserialize=True):
def __init__(
self,
stream,
local_addr: str,
peer_addr: str,
deserialize: bool = True,
):
self._closed = False
super().__init__()
self._local_addr = local_addr
Expand All @@ -156,7 +162,7 @@ def __init__(self, stream, local_addr, peer_addr, deserialize=True):
self.deserialize = deserialize
self._finalizer = weakref.finalize(self, self._get_finalizer())
self._finalizer.atexit = False
self._extra = {}
self._extra: dict = {}

ref = weakref.ref(self)

Expand All @@ -171,19 +177,20 @@ def _read_extra(self):

def _get_finalizer(self):
def finalize(stream=self.stream, r=repr(self)):
# stream is None if a StreamClosedError is raised during interpreter shutdown
# stream is None if a StreamClosedError is raised during interpreter
# shutdown
if stream is not None and not stream.closed():
logger.warning(f"Closing dangling stream in {r}")
stream.close()

return finalize

@property
def local_address(self):
def local_address(self) -> str:
return self._local_addr

@property
def peer_address(self):
def peer_address(self) -> str:
return self._peer_addr

async def read(self, deserializers=None):
Expand Down Expand Up @@ -391,8 +398,8 @@ async def connect(self, address, deserialize=True, **connection_args):
stream = await self.client.connect(
ip, port, max_buffer_size=MAX_BUFFER_SIZE, **kwargs
)
# Under certain circumstances tornado will have a closed connnection with an error and not raise
# a StreamClosedError.
# Under certain circumstances tornado will have a closed connnection with an
# error and not raise a StreamClosedError.
#
# This occurs with tornado 5.x and openssl 1.1+
if stream.closed() and stream.error:
Expand Down
Loading

0 comments on commit 3de84a1

Please sign in to comment.