From cf1f3c745a6966241ab967f43f73ecff632a337a Mon Sep 17 00:00:00 2001 From: John Litborn <11260241+jakkdl@users.noreply.github.com> Date: Sat, 29 Jul 2023 10:37:24 +0000 Subject: [PATCH] typecheck _socket and _core._local (#2705) * add type hints to _socket and _core.local --------- Co-authored-by: Spencer Brown Co-authored-by: EXPLOSION --- docs/source/conf.py | 2 + docs/source/reference-io.rst | 8 + pyproject.toml | 15 +- trio/_core/_local.py | 66 ++--- trio/_socket.py | 447 ++++++++++++++++++++++++++-------- trio/_sync.py | 19 +- trio/_tests/test_socket.py | 24 +- trio/_tests/verify_types.json | 38 +-- trio/_threads.py | 24 +- trio/socket.py | 1 + 10 files changed, 455 insertions(+), 189 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 650688717a..0e16b2d426 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -62,6 +62,8 @@ ("py:obj", "trio._abc.SendType"), ("py:obj", "trio._abc.T"), ("py:obj", "trio._abc.T_resource"), + ("py:class", "trio._threads.T"), + # why aren't these found in stdlib? ("py:class", "types.FrameType"), ("py:class", "P.args"), ("py:class", "P.kwargs"), diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst index 85969174aa..9207afb41b 100644 --- a/docs/source/reference-io.rst +++ b/docs/source/reference-io.rst @@ -504,6 +504,14 @@ Socket objects * :meth:`~socket.socket.set_inheritable` * :meth:`~socket.socket.get_inheritable` +The internal SocketType +~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: _SocketType +.. + TODO: adding `:members:` here gives error due to overload+_wraps on `sendto` + TODO: rewrite ... all of the above when fixing _SocketType vs SocketType + + .. currentmodule:: trio diff --git a/pyproject.toml b/pyproject.toml index d479442c7a..445c40e28c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,20 +48,25 @@ disallow_untyped_defs = false # downstream and users have to deal with them. [[tool.mypy.overrides]] module = [ - "trio._path", + "trio._socket", + "trio._core._local", + "trio._sync", "trio._file_io", ] +disallow_incomplete_defs = true disallow_untyped_defs = true +disallow_any_generics = true +disallow_any_decorated = true +disallow_subclassing_any = true [[tool.mypy.overrides]] module = [ - "trio._dtls", - "trio._abc" + "trio._path", ] disallow_incomplete_defs = true disallow_untyped_defs = true -disallow_any_generics = true -disallow_any_decorated = true +#disallow_any_generics = true +#disallow_any_decorated = true disallow_subclassing_any = true [tool.pytest.ini_options] diff --git a/trio/_core/_local.py b/trio/_core/_local.py index a54f424fdf..7f2c632153 100644 --- a/trio/_core/_local.py +++ b/trio/_core/_local.py @@ -1,25 +1,34 @@ +from __future__ import annotations + +from typing import Generic, TypeVar, final + # Runvar implementations import attr -from .._util import Final +from .._util import Final, NoPublicConstructor from . import _run +T = TypeVar("T") + + +@final +class _NoValue(metaclass=Final): + ... -@attr.s(eq=False, hash=False, slots=True) -class _RunVarToken: - _no_value = object() - _var = attr.ib() - previous_value = attr.ib(default=_no_value) - redeemed = attr.ib(default=False, init=False) +@attr.s(eq=False, hash=False, slots=False) +class RunVarToken(Generic[T], metaclass=NoPublicConstructor): + _var: RunVar[T] = attr.ib() + previous_value: T | type[_NoValue] = attr.ib(default=_NoValue) + redeemed: bool = attr.ib(default=False, init=False) @classmethod - def empty(cls, var): - return cls(var) + def _empty(cls, var: RunVar[T]) -> RunVarToken[T]: + return cls._create(var) @attr.s(eq=False, hash=False, slots=True) -class RunVar(metaclass=Final): +class RunVar(Generic[T], metaclass=Final): """The run-local variant of a context variable. :class:`RunVar` objects are similar to context variable objects, @@ -28,27 +37,28 @@ class RunVar(metaclass=Final): """ - _NO_DEFAULT = object() - _name = attr.ib() - _default = attr.ib(default=_NO_DEFAULT) + _name: str = attr.ib() + _default: T | type[_NoValue] = attr.ib(default=_NoValue) - def get(self, default=_NO_DEFAULT): + def get(self, default: T | type[_NoValue] = _NoValue) -> T: """Gets the value of this :class:`RunVar` for the current run call.""" try: - return _run.GLOBAL_RUN_CONTEXT.runner._locals[self] + # not typed yet + return _run.GLOBAL_RUN_CONTEXT.runner._locals[self] # type: ignore[return-value, index] except AttributeError: raise RuntimeError("Cannot be used outside of a run context") from None except KeyError: # contextvars consistency - if default is not self._NO_DEFAULT: - return default + # `type: ignore` awaiting https://github.com/python/mypy/issues/15553 to be fixed & released + if default is not _NoValue: + return default # type: ignore[return-value] - if self._default is not self._NO_DEFAULT: - return self._default + if self._default is not _NoValue: + return self._default # type: ignore[return-value] raise LookupError(self) from None - def set(self, value): + def set(self, value: T) -> RunVarToken[T]: """Sets the value of this :class:`RunVar` for this current run call. @@ -56,16 +66,16 @@ def set(self, value): try: old_value = self.get() except LookupError: - token = _RunVarToken.empty(self) + token = RunVarToken._empty(self) else: - token = _RunVarToken(self, old_value) + token = RunVarToken[T]._create(self, old_value) # This can't fail, because if we weren't in Trio context then the # get() above would have failed. - _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value + _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value # type: ignore[assignment, index] return token - def reset(self, token): + def reset(self, token: RunVarToken[T]) -> None: """Resets the value of this :class:`RunVar` to what it was previously specified by the token. @@ -81,14 +91,14 @@ def reset(self, token): previous = token.previous_value try: - if previous is _RunVarToken._no_value: - _run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self) + if previous is _NoValue: + _run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self) # type: ignore[arg-type] else: - _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous + _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous # type: ignore[index, assignment] except AttributeError: raise RuntimeError("Cannot be used outside of a run context") token.redeemed = True - def __repr__(self): + def __repr__(self) -> str: return f"" diff --git a/trio/_socket.py b/trio/_socket.py index 26b03fc3e0..b0ec1d480d 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -5,7 +5,20 @@ import socket as _stdlib_socket import sys from functools import wraps as _wraps -from typing import TYPE_CHECKING, Tuple, Union +from operator import index +from socket import AddressFamily, SocketKind +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + NoReturn, + SupportsIndex, + Tuple, + TypeVar, + Union, + overload, +) import idna as _idna @@ -17,7 +30,14 @@ from collections.abc import Iterable from types import TracebackType - from typing_extensions import Self, TypeAlias + from typing_extensions import Buffer, Concatenate, ParamSpec, Self, TypeAlias + + from ._abc import HostnameResolver, SocketFactory + + P = ParamSpec("P") + + +T = TypeVar("T") # must use old-style typing because it's evaluated at runtime Address: TypeAlias = Union[ @@ -34,16 +54,18 @@ # return await do_it_properly_with_a_check_point() # class _try_sync: - def __init__(self, blocking_exc_override=None): + def __init__( + self, blocking_exc_override: Callable[[BaseException], bool] | None = None + ): self._blocking_exc_override = blocking_exc_override - def _is_blocking_io_error(self, exc): + def _is_blocking_io_error(self, exc: BaseException) -> bool: if self._blocking_exc_override is None: return isinstance(exc, BlockingIOError) else: return self._blocking_exc_override(exc) - async def __aenter__(self): + async def __aenter__(self) -> None: await trio.lowlevel.checkpoint_if_cancelled() async def __aexit__( @@ -66,11 +88,13 @@ async def __aexit__( # Overrides ################################################################ -_resolver = _core.RunVar("hostname_resolver") -_socket_factory = _core.RunVar("socket_factory") +_resolver: _core.RunVar[HostnameResolver | None] = _core.RunVar("hostname_resolver") +_socket_factory: _core.RunVar[SocketFactory | None] = _core.RunVar("socket_factory") -def set_custom_hostname_resolver(hostname_resolver): +def set_custom_hostname_resolver( + hostname_resolver: HostnameResolver | None, +) -> HostnameResolver | None: """Set a custom hostname resolver. By default, Trio's :func:`getaddrinfo` and :func:`getnameinfo` functions @@ -102,7 +126,9 @@ def set_custom_hostname_resolver(hostname_resolver): return old -def set_custom_socket_factory(socket_factory): +def set_custom_socket_factory( + socket_factory: SocketFactory | None, +) -> SocketFactory | None: """Set a custom socket object factory. This function allows you to replace Trio's normal socket class with a @@ -136,7 +162,23 @@ def set_custom_socket_factory(socket_factory): _NUMERIC_ONLY = _stdlib_socket.AI_NUMERICHOST | _stdlib_socket.AI_NUMERICSERV -async def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0): +# It would be possible to @overload the return value depending on Literal[AddressFamily.INET/6], but should probably be added in typeshed first +async def getaddrinfo( + host: bytes | str | None, + port: bytes | str | int | None, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, +) -> list[ + tuple[ + AddressFamily, + SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int], + ] +]: """Look up a numeric address given a name. Arguments and return values are identical to :func:`socket.getaddrinfo`, @@ -157,7 +199,7 @@ async def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0): # skip the whole thread thing, which seems worthwhile. So we try first # with the _NUMERIC_ONLY flags set, and then only spawn a thread if that # fails with EAI_NONAME: - def numeric_only_failure(exc): + def numeric_only_failure(exc: BaseException) -> bool: return ( isinstance(exc, _stdlib_socket.gaierror) and exc.errno == _stdlib_socket.EAI_NONAME @@ -199,7 +241,9 @@ def numeric_only_failure(exc): ) -async def getnameinfo(sockaddr, flags): +async def getnameinfo( + sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int +) -> tuple[str, str]: """Look up a name given a numeric address. Arguments and return values are identical to :func:`socket.getnameinfo`, @@ -218,7 +262,7 @@ async def getnameinfo(sockaddr, flags): ) -async def getprotobyname(name): +async def getprotobyname(name: str) -> int: """Look up a protocol number by name. (Rarely used.) Like :func:`socket.getprotobyname`, but async. @@ -237,7 +281,7 @@ async def getprotobyname(name): ################################################################ -def from_stdlib_socket(sock): +def from_stdlib_socket(sock: _stdlib_socket.socket) -> _SocketType: """Convert a standard library :class:`socket.socket` object into a Trio socket object. @@ -246,9 +290,14 @@ def from_stdlib_socket(sock): @_wraps(_stdlib_socket.fromfd, assigned=(), updated=()) -def fromfd(fd, family, type, proto=0): +def fromfd( + fd: SupportsIndex, + family: AddressFamily | int = _stdlib_socket.AF_INET, + type: SocketKind | int = _stdlib_socket.SOCK_STREAM, + proto: int = 0, +) -> _SocketType: """Like :func:`socket.fromfd`, but returns a Trio socket object.""" - family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, fd) + family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, index(fd)) return from_stdlib_socket(_stdlib_socket.fromfd(fd, family, type, proto)) @@ -257,27 +306,41 @@ def fromfd(fd, family, type, proto=0): ): @_wraps(_stdlib_socket.fromshare, assigned=(), updated=()) - def fromshare(*args, **kwargs): - return from_stdlib_socket(_stdlib_socket.fromshare(*args, **kwargs)) + def fromshare(info: bytes) -> _SocketType: + return from_stdlib_socket(_stdlib_socket.fromshare(info)) + + +if sys.platform == "win32": + FamilyT: TypeAlias = int + TypeT: TypeAlias = int + FamilyDefault = _stdlib_socket.AF_INET +else: + FamilyDefault = None + FamilyT: TypeAlias = Union[int, AddressFamily, None] + TypeT: TypeAlias = Union[_stdlib_socket.socket, int] @_wraps(_stdlib_socket.socketpair, assigned=(), updated=()) -def socketpair(*args, **kwargs): +def socketpair( + family: FamilyT = FamilyDefault, + type: TypeT = SocketKind.SOCK_STREAM, + proto: int = 0, +) -> tuple[_SocketType, _SocketType]: """Like :func:`socket.socketpair`, but returns a pair of Trio socket objects. """ - left, right = _stdlib_socket.socketpair(*args, **kwargs) + left, right = _stdlib_socket.socketpair(family, type, proto) return (from_stdlib_socket(left), from_stdlib_socket(right)) @_wraps(_stdlib_socket.socket, assigned=(), updated=()) def socket( - family=_stdlib_socket.AF_INET, - type=_stdlib_socket.SOCK_STREAM, - proto=0, - fileno=None, -): + family: AddressFamily | int = _stdlib_socket.AF_INET, + type: SocketKind | int = _stdlib_socket.SOCK_STREAM, + proto: int = 0, + fileno: int | None = None, +) -> _SocketType: """Create a new Trio socket, like :class:`socket.socket`. This function's behavior can be customized using @@ -294,14 +357,24 @@ def socket( return from_stdlib_socket(stdlib_socket) -def _sniff_sockopts_for_fileno(family, type, proto, fileno): +def _sniff_sockopts_for_fileno( + family: AddressFamily | int, + type: SocketKind | int, + proto: int, + fileno: int | None, +) -> tuple[AddressFamily | int, SocketKind | int, int]: """Correct SOCKOPTS for given fileno, falling back to provided values.""" # Wrap the raw fileno into a Python socket object # This object might have the wrong metadata, but it lets us easily call getsockopt # and then we'll throw it away and construct a new one with the correct metadata. if sys.platform != "linux": return family, type, proto - from socket import SO_DOMAIN, SO_PROTOCOL, SO_TYPE, SOL_SOCKET + from socket import ( # type: ignore[attr-defined] + SO_DOMAIN, + SO_PROTOCOL, + SO_TYPE, + SOL_SOCKET, + ) sockobj = _stdlib_socket.socket(family, type, proto, fileno=fileno) try: @@ -331,19 +404,21 @@ def _sniff_sockopts_for_fileno(family, type, proto, fileno): ) -def _make_simple_sock_method_wrapper(methname, wait_fn, maybe_avail=False): - fn = getattr(_stdlib_socket.socket, methname) - +def _make_simple_sock_method_wrapper( + fn: Callable[Concatenate[_stdlib_socket.socket, P], T], + wait_fn: Callable[[_stdlib_socket.socket], Awaitable[None]], + maybe_avail: bool = False, +) -> Callable[Concatenate[_SocketType, P], Awaitable[T]]: @_wraps(fn, assigned=("__name__",), updated=()) - async def wrapper(self, *args, **kwargs): - return await self._nonblocking_helper(fn, args, kwargs, wait_fn) + async def wrapper(self: _SocketType, *args: P.args, **kwargs: P.kwargs) -> T: + return await self._nonblocking_helper(wait_fn, fn, *args, **kwargs) - wrapper.__doc__ = f"""Like :meth:`socket.socket.{methname}`, but async. + wrapper.__doc__ = f"""Like :meth:`socket.socket.{fn.__name__}`, but async. """ if maybe_avail: wrapper.__doc__ += ( - f"Only available on platforms where :meth:`socket.socket.{methname}` is " + f"Only available on platforms where :meth:`socket.socket.{fn.__name__}` is " "available." ) return wrapper @@ -362,8 +437,21 @@ async def wrapper(self, *args, **kwargs): # local=False means that the address is being used with connect() or sendto() or # similar. # + + +# Using a TypeVar to indicate we return the same type of address appears to give errors +# when passed a union of address types. +# @overload likely works, but is extremely verbose. # NOTE: this function does not always checkpoint -async def _resolve_address_nocp(type, family, proto, *, ipv6_v6only, address, local): +async def _resolve_address_nocp( + type: int, + family: AddressFamily, + proto: int, + *, + ipv6_v6only: bool | int, + address: Address, + local: bool, +) -> Address: # Do some pre-checking (or exit early for non-IP sockets) if family == _stdlib_socket.AF_INET: if not isinstance(address, tuple) or not len(address) == 2: @@ -373,13 +461,15 @@ async def _resolve_address_nocp(type, family, proto, *, ipv6_v6only, address, lo raise ValueError( "address should be a (host, port, [flowinfo, [scopeid]]) tuple" ) - elif family == _stdlib_socket.AF_UNIX: + elif family == getattr(_stdlib_socket, "AF_UNIX"): # unwrap path-likes + assert isinstance(address, (str, bytes)) return os.fspath(address) else: return address # -- From here on we know we have IPv4 or IPV6 -- + host: str | None host, port, *_ = address # Fast path for the simple case: already-resolved IP address, # already-resolved port. This is particularly important for UDP, since @@ -417,18 +507,24 @@ async def _resolve_address_nocp(type, family, proto, *, ipv6_v6only, address, lo # The above ignored any flowid and scopeid in the passed-in address, # so restore them if present: if family == _stdlib_socket.AF_INET6: - normed = list(normed) + list_normed = list(normed) assert len(normed) == 4 + # typechecking certainly doesn't like this logic, but given just how broad + # Address is, it's quite cumbersome to write the below without type: ignore if len(address) >= 3: - normed[2] = address[2] + list_normed[2] = address[2] # type: ignore if len(address) >= 4: - normed[3] = address[3] - normed = tuple(normed) + list_normed[3] = address[3] # type: ignore + return tuple(list_normed) # type: ignore return normed +# TODO: stopping users from initializing this type should be done in a different way, +# so SocketType can be used as a type. Note that this is *far* from trivial without +# breaking subclasses of SocketType. Can maybe add abstract methods to SocketType, +# or rename _SocketType. class SocketType: - def __init__(self): + def __init__(self) -> NoReturn: raise TypeError( "SocketType is an abstract class; use trio.socket.socket if you " "want to construct a socket object" @@ -451,36 +547,80 @@ def __init__(self, sock: _stdlib_socket.socket): # Simple + portable methods and attributes ################################################################ - # NB this doesn't work because for loops don't create a scope - # for _name in [ - # ]: - # _meth = getattr(_stdlib_socket.socket, _name) - # @_wraps(_meth, assigned=("__name__", "__doc__"), updated=()) - # def _wrapped(self, *args, **kwargs): - # return getattr(self._sock, _meth)(*args, **kwargs) - # locals()[_meth] = _wrapped - # del _name, _meth, _wrapped - - _forward = { - "detach", - "get_inheritable", - "set_inheritable", - "fileno", - "getpeername", - "getsockname", - "getsockopt", - "setsockopt", - "listen", - "share", - } - - def __getattr__(self, name): - if name in self._forward: - return getattr(self._sock, name) - raise AttributeError(name) - - def __dir__(self) -> Iterable[str]: - return [*super().__dir__(), *self._forward] + # forwarded methods + def detach(self) -> int: + return self._sock.detach() + + def fileno(self) -> int: + return self._sock.fileno() + + def getpeername(self) -> Any: + return self._sock.getpeername() + + def getsockname(self) -> Any: + return self._sock.getsockname() + + @overload + def getsockopt(self, /, level: int, optname: int) -> int: + ... + + @overload + def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: + ... + + def getsockopt( + self, /, level: int, optname: int, buflen: int | None = None + ) -> int | bytes: + if buflen is None: + return self._sock.getsockopt(level, optname) + return self._sock.getsockopt(level, optname, buflen) + + @overload + def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: + ... + + @overload + def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None: + ... + + def setsockopt( + self, + /, + level: int, + optname: int, + value: int | Buffer | None, + optlen: int | None = None, + ) -> None: + if optlen is None: + if value is None: + raise TypeError( + "invalid value for argument 'value', must not be None when specifying optlen" + ) + return self._sock.setsockopt(level, optname, value) + if value is not None: + raise TypeError( + "invalid value for argument 'value': {value!r}, must be None when specifying optlen" + ) + + # Note: PyPy may crash here due to setsockopt only supporting + # four parameters. + return self._sock.setsockopt(level, optname, value, optlen) + + def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None: + return self._sock.listen(backlog) + + def get_inheritable(self) -> bool: + return self._sock.get_inheritable() + + def set_inheritable(self, inheritable: bool) -> None: + return self._sock.set_inheritable(inheritable) + + if sys.platform == "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "share") + ): + + def share(self, /, process_id: int) -> bytes: + return self._sock.share(process_id) def __enter__(self) -> Self: return self @@ -494,11 +634,11 @@ def __exit__( return self._sock.__exit__(exc_type, exc_value, traceback) @property - def family(self) -> _stdlib_socket.AddressFamily: + def family(self) -> AddressFamily: return self._sock.family @property - def type(self) -> _stdlib_socket.SocketKind: + def type(self) -> SocketKind: return self._sock.type @property @@ -521,7 +661,7 @@ def close(self) -> None: trio.lowlevel.notify_closing(self._sock) self._sock.close() - async def bind(self, address: tuple[object, ...] | str | bytes) -> None: + async def bind(self, address: Address) -> None: address = await self._resolve_address_nocp(address, local=True) if ( hasattr(_stdlib_socket, "AF_UNIX") @@ -530,8 +670,7 @@ async def bind(self, address: tuple[object, ...] | str | bytes) -> None: ): # Use a thread for the filesystem traversal (unless it's an # abstract domain socket) - # remove the `type: ignore` when run.sync is typed. - return await trio.to_thread.run_sync(self._sock.bind, address) # type: ignore[no-any-return] + return await trio.to_thread.run_sync(self._sock.bind, address) else: # POSIX actually says that bind can return EWOULDBLOCK and # complete asynchronously, like connect. But in practice AFAICT @@ -559,7 +698,12 @@ def is_readable(self) -> bool: async def wait_writable(self) -> None: await _core.wait_writable(self._sock) - async def _resolve_address_nocp(self, address, *, local): + async def _resolve_address_nocp( + self, + address: Address, + *, + local: bool, + ) -> Address: if self.family == _stdlib_socket.AF_INET6: ipv6_v6only = self._sock.getsockopt( _stdlib_socket.IPPROTO_IPV6, _stdlib_socket.IPV6_V6ONLY @@ -575,7 +719,19 @@ async def _resolve_address_nocp(self, address, *, local): local=local, ) - async def _nonblocking_helper(self, fn, args, kwargs, wait_fn): + # args and kwargs must be starred, otherwise pyright complains: + # '"args" member of ParamSpec is valid only when used with *args parameter' + # '"kwargs" member of ParamSpec is valid only when used with **kwargs parameter' + # wait_fn and fn must also be first in the signature + # 'Keyword parameter cannot appear in signature after ParamSpec args parameter' + + async def _nonblocking_helper( + self, + wait_fn: Callable[[_stdlib_socket.socket], Awaitable[None]], + fn: Callable[Concatenate[_stdlib_socket.socket, P], T], + *args: P.args, + **kwargs: P.kwargs, + ) -> T: # We have to reconcile two conflicting goals: # - We want to make it look like we always blocked in doing these # operations. The obvious way is to always do an IO wait before @@ -611,9 +767,11 @@ async def _nonblocking_helper(self, fn, args, kwargs, wait_fn): # accept ################################################################ - _accept = _make_simple_sock_method_wrapper("accept", _core.wait_readable) + _accept = _make_simple_sock_method_wrapper( + _stdlib_socket.socket.accept, _core.wait_readable + ) - async def accept(self): + async def accept(self) -> tuple[_SocketType, object]: """Like :meth:`socket.socket.accept`, but async.""" sock, addr = await self._accept() return from_stdlib_socket(sock), addr @@ -622,7 +780,7 @@ async def accept(self): # connect ################################################################ - async def connect(self, address): + async def connect(self, address: Address) -> None: # nonblocking connect is weird -- you call it to start things # off, then the socket becomes writable as a completion # notification. This means it isn't really cancellable... we close the @@ -690,38 +848,71 @@ async def connect(self, address): # Okay, the connect finished, but it might have failed: err = self._sock.getsockopt(_stdlib_socket.SOL_SOCKET, _stdlib_socket.SO_ERROR) if err != 0: - raise OSError(err, f"Error connecting to {address}: {os.strerror(err)}") + raise OSError(err, f"Error connecting to {address!r}: {os.strerror(err)}") ################################################################ # recv ################################################################ + # Not possible to typecheck with a Callable (due to DefaultArg), nor with a + # callback Protocol (https://github.com/python/typing/discussions/1040) + # but this seems to work. If not explicitly defined then pyright --verifytypes will + # complain about AmbiguousType if TYPE_CHECKING: - async def recv(self, buffersize: int, flags: int = 0) -> bytes: + def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]: ... - else: - recv = _make_simple_sock_method_wrapper("recv", _core.wait_readable) + # _make_simple_sock_method_wrapper is typed, so this checks that the above is correct + # this requires that we refrain from using `/` to specify pos-only + # args, or mypy thinks the signature differs from typeshed. + recv = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.recv, _core.wait_readable + ) ################################################################ # recv_into ################################################################ - recv_into = _make_simple_sock_method_wrapper("recv_into", _core.wait_readable) + if TYPE_CHECKING: + + def recv_into( + __self, buffer: Buffer, nbytes: int = 0, flags: int = 0 + ) -> Awaitable[int]: + ... + + recv_into = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.recv_into, _core.wait_readable + ) ################################################################ # recvfrom ################################################################ - recvfrom = _make_simple_sock_method_wrapper("recvfrom", _core.wait_readable) + if TYPE_CHECKING: + # return type of socket.socket.recvfrom in typeshed is tuple[bytes, Any] + def recvfrom( + __self, __bufsize: int, __flags: int = 0 + ) -> Awaitable[tuple[bytes, Address]]: + ... + + recvfrom = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.recvfrom, _core.wait_readable + ) ################################################################ # recvfrom_into ################################################################ - recvfrom_into = _make_simple_sock_method_wrapper( - "recvfrom_into", _core.wait_readable + if TYPE_CHECKING: + # return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any] + def recvfrom_into( + __self, buffer: Buffer, nbytes: int = 0, flags: int = 0 + ) -> Awaitable[tuple[int, Address]]: + ... + + recvfrom_into = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.recvfrom_into, _core.wait_readable ) ################################################################ @@ -729,8 +920,15 @@ async def recv(self, buffersize: int, flags: int = 0) -> bytes: ################################################################ if hasattr(_stdlib_socket.socket, "recvmsg"): - recvmsg = _make_simple_sock_method_wrapper( - "recvmsg", _core.wait_readable, maybe_avail=True + if TYPE_CHECKING: + + def recvmsg( + __self, __bufsize: int, __ancbufsize: int = 0, __flags: int = 0 + ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]: + ... + + recvmsg = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.recvmsg, _core.wait_readable, maybe_avail=True ) ################################################################ @@ -738,29 +936,58 @@ async def recv(self, buffersize: int, flags: int = 0) -> bytes: ################################################################ if hasattr(_stdlib_socket.socket, "recvmsg_into"): - recvmsg_into = _make_simple_sock_method_wrapper( - "recvmsg_into", _core.wait_readable, maybe_avail=True + if TYPE_CHECKING: + + def recvmsg_into( + __self, + __buffers: Iterable[Buffer], + __ancbufsize: int = 0, + __flags: int = 0, + ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]: + ... + + recvmsg_into = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.recvmsg_into, _core.wait_readable, maybe_avail=True ) ################################################################ # send ################################################################ - send = _make_simple_sock_method_wrapper("send", _core.wait_writable) + if TYPE_CHECKING: + + def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]: + ... + + send = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.send, _core.wait_writable + ) ################################################################ # sendto ################################################################ - @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) - async def sendto(self, *args): + @overload + async def sendto( + self, __data: Buffer, __address: tuple[Any, ...] | str | Buffer + ) -> int: + ... + + @overload + async def sendto( + self, __data: Buffer, __flags: int, __address: tuple[Any, ...] | str | Buffer + ) -> int: + ... + + @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) # type: ignore[misc] + async def sendto(self, *args: Any) -> int: """Similar to :meth:`socket.socket.sendto`, but async.""" # args is: data[, flags], address) # and kwargs are not accepted - args = list(args) - args[-1] = await self._resolve_address_nocp(args[-1], local=False) + args_list = list(args) + args_list[-1] = await self._resolve_address_nocp(args[-1], local=False) return await self._nonblocking_helper( - _stdlib_socket.socket.sendto, args, {}, _core.wait_writable + _core.wait_writable, _stdlib_socket.socket.sendto, *args_list ) ################################################################ @@ -772,20 +999,28 @@ async def sendto(self, *args): ): @_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=()) - async def sendmsg(self, *args): + async def sendmsg( + self, + __buffers: Iterable[Buffer], + __ancdata: Iterable[tuple[int, int, Buffer]] = (), + __flags: int = 0, + __address: Address | None = None, + ) -> int: """Similar to :meth:`socket.socket.sendmsg`, but async. Only available on platforms where :meth:`socket.socket.sendmsg` is available. """ - # args is: buffers[, ancdata[, flags[, address]]] - # and kwargs are not accepted - if len(args) == 4 and args[-1] is not None: - args = list(args) - args[-1] = await self._resolve_address_nocp(args[-1], local=False) + if __address is not None: + __address = await self._resolve_address_nocp(__address, local=False) return await self._nonblocking_helper( - _stdlib_socket.socket.sendmsg, args, {}, _core.wait_writable + _core.wait_writable, + _stdlib_socket.socket.sendmsg, + __buffers, + __ancdata, + __flags, + __address, ) ################################################################ diff --git a/trio/_sync.py b/trio/_sync.py index 5a7f240d5e..bd2122858e 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -8,7 +8,7 @@ import trio from . import _core -from ._core import ParkingLot, enable_ki_protection +from ._core import Abort, ParkingLot, RaiseCancelT, enable_ki_protection from ._util import Final if TYPE_CHECKING: @@ -87,7 +87,7 @@ async def wait(self) -> None: task = _core.current_task() self._tasks.add(task) - def abort_fn(_): + def abort_fn(_: RaiseCancelT) -> Abort: self._tasks.remove(task) return _core.Abort.SUCCEEDED @@ -143,10 +143,13 @@ class CapacityLimiterStatistics: borrowed_tokens: int = attr.ib() total_tokens: int | float = attr.ib() - borrowers: list[Task] = attr.ib() + borrowers: list[Task | object] = attr.ib() tasks_waiting: int = attr.ib() +# Can be a generic type with a default of Task if/when PEP 696 is released +# and implemented in type checkers. Making it fully generic would currently +# introduce a lot of unnecessary hassle. class CapacityLimiter(AsyncContextManagerMixin, metaclass=Final): """An object for controlling access to a resource with limited capacity. @@ -204,9 +207,9 @@ class CapacityLimiter(AsyncContextManagerMixin, metaclass=Final): # total_tokens would ideally be int|Literal[math.inf] - but that's not valid typing def __init__(self, total_tokens: int | float): self._lot = ParkingLot() - self._borrowers: set[Task] = set() + self._borrowers: set[Task | object] = set() # Maps tasks attempting to acquire -> borrower, to handle on-behalf-of - self._pending_borrowers: dict[Task, Task] = {} + self._pending_borrowers: dict[Task, Task | object] = {} # invoke the property setter for validation self.total_tokens: int | float = total_tokens assert self._total_tokens == total_tokens @@ -268,7 +271,7 @@ def acquire_nowait(self) -> None: self.acquire_on_behalf_of_nowait(trio.lowlevel.current_task()) @enable_ki_protection - def acquire_on_behalf_of_nowait(self, borrower: Task) -> None: + def acquire_on_behalf_of_nowait(self, borrower: Task | object) -> None: """Borrow a token from the sack on behalf of ``borrower``, without blocking. @@ -307,7 +310,7 @@ async def acquire(self) -> None: await self.acquire_on_behalf_of(trio.lowlevel.current_task()) @enable_ki_protection - async def acquire_on_behalf_of(self, borrower: Task) -> None: + async def acquire_on_behalf_of(self, borrower: Task | object) -> None: """Borrow a token from the sack on behalf of ``borrower``, blocking if necessary. @@ -347,7 +350,7 @@ def release(self) -> None: self.release_on_behalf_of(trio.lowlevel.current_task()) @enable_ki_protection - def release_on_behalf_of(self, borrower: Task) -> None: + def release_on_behalf_of(self, borrower: Task | object) -> None: """Put a token back into the sack on behalf of ``borrower``. Raises: diff --git a/trio/_tests/test_socket.py b/trio/_tests/test_socket.py index e559b98240..e9baff436a 100644 --- a/trio/_tests/test_socket.py +++ b/trio/_tests/test_socket.py @@ -2,7 +2,7 @@ import inspect import os import socket as stdlib_socket -import sys as _sys +import sys import tempfile import attr @@ -277,7 +277,7 @@ async def test_socket_v6(): assert s.family == tsocket.AF_INET6 -@pytest.mark.skipif(not _sys.platform == "linux", reason="linux only") +@pytest.mark.skipif(not sys.platform == "linux", reason="linux only") async def test_sniff_sockopts(): from socket import AF_INET, AF_INET6, SOCK_DGRAM, SOCK_STREAM @@ -360,6 +360,26 @@ async def test_SocketType_basics(): sock.close() +async def test_SocketType_setsockopt(): + sock = tsocket.socket() + with sock as _: + # specifying optlen. Not supported on pypy, and I couldn't find + # valid calls on darwin or win32. + if hasattr(tsocket, "SO_BINDTODEVICE"): + sock.setsockopt(tsocket.SOL_SOCKET, tsocket.SO_BINDTODEVICE, None, 0) + + # specifying value + sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False) + + # specifying both + with pytest.raises(TypeError, match="invalid value for argument 'value'"): + sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False, 5) # type: ignore[call-overload] + + # specifying neither + with pytest.raises(TypeError, match="invalid value for argument 'value'"): + sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, None) # type: ignore[call-overload] + + async def test_SocketType_dup(): a, b = tsocket.socketpair() with a, b: diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index d08c03060c..60132e07fd 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,11 +7,11 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.888, + "completenessScore": 0.9072, "exportedSymbolCounts": { "withAmbiguousType": 1, - "withKnownType": 555, - "withUnknownType": 69 + "withKnownType": 567, + "withUnknownType": 57 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 1, @@ -46,18 +46,13 @@ ], "otherSymbolCounts": { "withAmbiguousType": 3, - "withKnownType": 529, - "withUnknownType": 102 + "withKnownType": 574, + "withUnknownType": 76 }, "packageName": "trio", "symbols": [ "trio.__deprecated_attributes__", - "trio._abc.SocketFactory.socket", "trio._core._entry_queue.TrioToken.run_sync_soon", - "trio._core._local.RunVar.__repr__", - "trio._core._local.RunVar.get", - "trio._core._local.RunVar.reset", - "trio._core._local.RunVar.set", "trio._core._mock_clock.MockClock.jump", "trio._core._run.Nursery.start", "trio._core._run.Nursery.start_soon", @@ -72,24 +67,10 @@ "trio._core._unbounded_queue.UnboundedQueue.qsize", "trio._core._unbounded_queue.UnboundedQueue.statistics", "trio._dtls.DTLSChannel.__init__", - "trio._dtls.DTLSEndpoint.__init__", "trio._dtls.DTLSEndpoint.serve", - "trio._highlevel_socket.SocketListener.__init__", - "trio._highlevel_socket.SocketStream.__init__", "trio._highlevel_socket.SocketStream.getsockopt", "trio._highlevel_socket.SocketStream.send_all", "trio._highlevel_socket.SocketStream.setsockopt", - "trio._socket._SocketType.__getattr__", - "trio._socket._SocketType.accept", - "trio._socket._SocketType.connect", - "trio._socket._SocketType.recv_into", - "trio._socket._SocketType.recvfrom", - "trio._socket._SocketType.recvfrom_into", - "trio._socket._SocketType.recvmsg", - "trio._socket._SocketType.recvmsg_into", - "trio._socket._SocketType.send", - "trio._socket._SocketType.sendmsg", - "trio._socket._SocketType.sendto", "trio._ssl.SSLListener.__init__", "trio._ssl.SSLListener.accept", "trio._ssl.SSLListener.aclose", @@ -148,15 +129,6 @@ "trio.serve_listeners", "trio.serve_ssl_over_tcp", "trio.serve_tcp", - "trio.socket.from_stdlib_socket", - "trio.socket.fromfd", - "trio.socket.getaddrinfo", - "trio.socket.getnameinfo", - "trio.socket.getprotobyname", - "trio.socket.set_custom_hostname_resolver", - "trio.socket.set_custom_socket_factory", - "trio.socket.socket", - "trio.socket.socketpair", "trio.testing._memory_streams.MemoryReceiveStream.__init__", "trio.testing._memory_streams.MemoryReceiveStream.aclose", "trio.testing._memory_streams.MemoryReceiveStream.close", diff --git a/trio/_threads.py b/trio/_threads.py index 807212e0f9..3fbab05750 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -1,16 +1,19 @@ +from __future__ import annotations + import contextvars import functools import inspect import queue as stdlib_queue import threading from itertools import count -from typing import Optional +from typing import Any, Callable, Optional, TypeVar import attr import outcome from sniffio import current_async_library_cvar import trio +from trio._core._traps import RaiseCancelT from ._core import ( RunVar, @@ -22,10 +25,12 @@ from ._sync import CapacityLimiter from ._util import coroutine_or_error +T = TypeVar("T") + # Global due to Threading API, thread local storage for trio token TOKEN_LOCAL = threading.local() -_limiter_local = RunVar("limiter") +_limiter_local: RunVar[CapacityLimiter] = RunVar("limiter") # I pulled this number out of the air; it isn't based on anything. Probably we # should make some kind of measurements to pick a good value. DEFAULT_LIMIT = 40 @@ -59,8 +64,12 @@ class ThreadPlaceholder: @enable_ki_protection async def to_thread_run_sync( - sync_fn, *args, thread_name: Optional[str] = None, cancellable=False, limiter=None -): + sync_fn: Callable[..., T], + *args: Any, + thread_name: Optional[str] = None, + cancellable: bool = False, + limiter: CapacityLimiter | None = None, +) -> T: """Convert a blocking operation into an async operation using a thread. These two lines are equivalent:: @@ -152,7 +161,7 @@ async def to_thread_run_sync( # Holds a reference to the task that's blocked in this function waiting # for the result – or None if this function was cancelled and we should # discard the result. - task_register = [trio.lowlevel.current_task()] + task_register: list[trio.lowlevel.Task | None] = [trio.lowlevel.current_task()] name = f"trio.to_thread.run_sync-{next(_thread_counter)}" placeholder = ThreadPlaceholder(name) @@ -217,14 +226,15 @@ def deliver_worker_fn_result(result): limiter.release_on_behalf_of(placeholder) raise - def abort(_): + def abort(_: RaiseCancelT) -> trio.lowlevel.Abort: if cancellable: task_register[0] = None return trio.lowlevel.Abort.SUCCEEDED else: return trio.lowlevel.Abort.FAILED - return await trio.lowlevel.wait_task_rescheduled(abort) + # wait_task_rescheduled return value cannot be typed + return await trio.lowlevel.wait_task_rescheduled(abort) # type: ignore[no-any-return] def _run_fn_as_system_task(cb, fn, *args, context, trio_token=None): diff --git a/trio/socket.py b/trio/socket.py index a9e276c782..f6aebb6a6e 100644 --- a/trio/socket.py +++ b/trio/socket.py @@ -35,6 +35,7 @@ # import the overwrites from ._socket import ( SocketType as SocketType, + _SocketType as _SocketType, from_stdlib_socket as from_stdlib_socket, fromfd as fromfd, getaddrinfo as getaddrinfo,