Skip to content

Commit 24226c3

Browse files
jakkdlCoolCat467
andauthored
Typing SocketType, test_socket, test_highlevel_[socket, open_tcp_stream, open_tcp_listeners] (#2774)
* Removes _SocketType as a public type, expanding SocketType to have "abstract" methods raising `NotImplementedError`. * Also replaces several methods taking `socket` as a type with taking `HasFileNo`. * Removes `Address` type alias due to complexity, using `Any` instead. * Add types to test_socket, test_highlevel_socket, test_highlevel_open_tcp_stream, test_highlevel_open_tcp_listeners. * Update `FakeSocket`s and `FakeSocketFactory`s to be compatible with signatures in `SocketType`. --------- Co-authored-by: CoolCat467 <52022020+CoolCat467@users.noreply.github.com>
1 parent 0eaa91a commit 24226c3

24 files changed

+810
-342
lines changed

.coveragerc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ exclude_lines =
2525
if t.TYPE_CHECKING:
2626
@overload
2727
class .*\bProtocol\b.*\):
28+
raise NotImplementedError
2829

2930
partial_branches =
3031
pragma: no branch

docs/source/conf.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@
5454
("py:class", "sync function"),
5555
# why aren't these found in stdlib?
5656
("py:class", "types.FrameType"),
57-
# TODO: temporary type
58-
("py:class", "_SocketType"),
5957
# these are not defined in https://docs.python.org/3/objects.inv
6058
("py:class", "socket.AddressFamily"),
6159
("py:class", "socket.SocketKind"),

docs/source/reference-io.rst

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -504,13 +504,6 @@ Socket objects
504504
* :meth:`~socket.socket.set_inheritable`
505505
* :meth:`~socket.socket.get_inheritable`
506506

507-
The internal SocketType
508-
~~~~~~~~~~~~~~~~~~~~~~~~~~
509-
.. autoclass:: _SocketType
510-
..
511-
TODO: adding `:members:` here gives error due to overload+_wraps on `sendto`
512-
TODO: rewrite ... all of the above when fixing _SocketType vs SocketType
513-
514507

515508
.. currentmodule:: trio
516509

pyproject.toml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,11 @@ module = [
7979
"trio/_tests/test_exports",
8080
"trio/_tests/test_file_io",
8181
"trio/_tests/test_highlevel_generic",
82-
"trio/_tests/test_highlevel_open_tcp_listeners",
83-
"trio/_tests/test_highlevel_open_tcp_stream",
8482
"trio/_tests/test_highlevel_open_unix_stream",
8583
"trio/_tests/test_highlevel_serve_listeners",
86-
"trio/_tests/test_highlevel_socket",
8784
"trio/_tests/test_highlevel_ssl_helpers",
8885
"trio/_tests/test_path",
8986
"trio/_tests/test_scheduler_determinism",
90-
"trio/_tests/test_socket",
9187
"trio/_tests/test_ssl",
9288
"trio/_tests/test_subprocess",
9389
"trio/_tests/test_sync",

trio/_abc.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from typing_extensions import Self
1313

1414
# both of these introduce circular imports if outside a TYPE_CHECKING guard
15-
from ._socket import _SocketType
15+
from ._socket import SocketType
1616
from .lowlevel import Task
1717

1818

@@ -211,10 +211,10 @@ class SocketFactory(metaclass=ABCMeta):
211211
@abstractmethod
212212
def socket(
213213
self,
214-
family: socket.AddressFamily | int | None = None,
215-
type: socket.SocketKind | int | None = None,
216-
proto: int | None = None,
217-
) -> _SocketType:
214+
family: socket.AddressFamily | int = socket.AF_INET,
215+
type: socket.SocketKind | int = socket.SOCK_STREAM,
216+
proto: int = 0,
217+
) -> SocketType:
218218
"""Create and return a socket object.
219219
220220
Your socket object must inherit from :class:`trio.socket.SocketType`,

trio/_core/_generated_io_epoll.py

Lines changed: 8 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

trio/_core/_generated_io_kqueue.py

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

trio/_core/_io_epoll.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,10 @@
1313
from ._wakeup_socketpair import WakeupSocketpair
1414

1515
if TYPE_CHECKING:
16-
from socket import socket
17-
1816
from typing_extensions import TypeAlias
1917

2018
from .._core import Abort, RaiseCancelT
19+
from .._file_io import _HasFileNo
2120

2221

2322
@attr.s(slots=True, eq=False)
@@ -290,7 +289,7 @@ def _update_registrations(self, fd: int) -> None:
290289
if not wanted_flags:
291290
del self._registered[fd]
292291

293-
async def _epoll_wait(self, fd: int | socket, attr_name: str) -> None:
292+
async def _epoll_wait(self, fd: int | _HasFileNo, attr_name: str) -> None:
294293
if not isinstance(fd, int):
295294
fd = fd.fileno()
296295
waiters = self._registered[fd]
@@ -309,15 +308,15 @@ def abort(_: RaiseCancelT) -> Abort:
309308
await _core.wait_task_rescheduled(abort)
310309

311310
@_public
312-
async def wait_readable(self, fd: int | socket) -> None:
311+
async def wait_readable(self, fd: int | _HasFileNo) -> None:
313312
await self._epoll_wait(fd, "read_task")
314313

315314
@_public
316-
async def wait_writable(self, fd: int | socket) -> None:
315+
async def wait_writable(self, fd: int | _HasFileNo) -> None:
317316
await self._epoll_wait(fd, "write_task")
318317

319318
@_public
320-
def notify_closing(self, fd: int | socket) -> None:
319+
def notify_closing(self, fd: int | _HasFileNo) -> None:
321320
if not isinstance(fd, int):
322321
fd = fd.fileno()
323322
wake_all(

trio/_core/_io_kqueue.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,10 @@
1414
from ._wakeup_socketpair import WakeupSocketpair
1515

1616
if TYPE_CHECKING:
17-
from socket import socket
18-
1917
from typing_extensions import TypeAlias
2018

2119
from .._core import Abort, RaiseCancelT, Task, UnboundedQueue
20+
from .._file_io import _HasFileNo
2221

2322
assert not TYPE_CHECKING or (sys.platform != "linux" and sys.platform != "win32")
2423

@@ -149,7 +148,7 @@ def abort(raise_cancel: RaiseCancelT) -> Abort:
149148
# wait_task_rescheduled does not have its return type typed
150149
return await _core.wait_task_rescheduled(abort) # type: ignore[no-any-return]
151150

152-
async def _wait_common(self, fd: int | socket, filter: int) -> None:
151+
async def _wait_common(self, fd: int | _HasFileNo, filter: int) -> None:
153152
if not isinstance(fd, int):
154153
fd = fd.fileno()
155154
flags = select.KQ_EV_ADD | select.KQ_EV_ONESHOT
@@ -181,15 +180,15 @@ def abort(_: RaiseCancelT) -> Abort:
181180
await self.wait_kevent(fd, filter, abort)
182181

183182
@_public
184-
async def wait_readable(self, fd: int | socket) -> None:
183+
async def wait_readable(self, fd: int | _HasFileNo) -> None:
185184
await self._wait_common(fd, select.KQ_FILTER_READ)
186185

187186
@_public
188-
async def wait_writable(self, fd: int | socket) -> None:
187+
async def wait_writable(self, fd: int | _HasFileNo) -> None:
189188
await self._wait_common(fd, select.KQ_FILTER_WRITE)
190189

191190
@_public
192-
def notify_closing(self, fd: int | socket) -> None:
191+
def notify_closing(self, fd: int | _HasFileNo) -> None:
193192
if not isinstance(fd, int):
194193
fd = fd.fileno()
195194

trio/_dtls.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,26 +43,26 @@
4343
from OpenSSL.SSL import Context
4444
from typing_extensions import Self, TypeAlias
4545

46-
from trio.socket import Address, _SocketType
46+
from trio.socket import SocketType
4747

4848
MAX_UDP_PACKET_SIZE = 65527
4949

5050

51-
def packet_header_overhead(sock: _SocketType) -> int:
51+
def packet_header_overhead(sock: SocketType) -> int:
5252
if sock.family == trio.socket.AF_INET:
5353
return 28
5454
else:
5555
return 48
5656

5757

58-
def worst_case_mtu(sock: _SocketType) -> int:
58+
def worst_case_mtu(sock: SocketType) -> int:
5959
if sock.family == trio.socket.AF_INET:
6060
return 576 - packet_header_overhead(sock)
6161
else:
6262
return 1280 - packet_header_overhead(sock)
6363

6464

65-
def best_guess_mtu(sock: _SocketType) -> int:
65+
def best_guess_mtu(sock: SocketType) -> int:
6666
return 1500 - packet_header_overhead(sock)
6767

6868

@@ -563,7 +563,7 @@ def _signable(*fields: bytes) -> bytes:
563563

564564

565565
def _make_cookie(
566-
key: bytes, salt: bytes, tick: int, address: Address, client_hello_bits: bytes
566+
key: bytes, salt: bytes, tick: int, address: Any, client_hello_bits: bytes
567567
) -> bytes:
568568
assert len(salt) == SALT_BYTES
569569
assert len(key) == KEY_BYTES
@@ -581,7 +581,7 @@ def _make_cookie(
581581

582582

583583
def valid_cookie(
584-
key: bytes, cookie: bytes, address: Address, client_hello_bits: bytes
584+
key: bytes, cookie: bytes, address: Any, client_hello_bits: bytes
585585
) -> bool:
586586
if len(cookie) > SALT_BYTES:
587587
salt = cookie[:SALT_BYTES]
@@ -603,7 +603,7 @@ def valid_cookie(
603603

604604

605605
def challenge_for(
606-
key: bytes, address: Address, epoch_seqno: int, client_hello_bits: bytes
606+
key: bytes, address: Any, epoch_seqno: int, client_hello_bits: bytes
607607
) -> bytes:
608608
salt = os.urandom(SALT_BYTES)
609609
tick = _current_cookie_tick()
@@ -664,7 +664,7 @@ def _read_loop(read_fn: Callable[[int], bytes]) -> bytes:
664664

665665

666666
async def handle_client_hello_untrusted(
667-
endpoint: DTLSEndpoint, address: Address, packet: bytes
667+
endpoint: DTLSEndpoint, address: Any, packet: bytes
668668
) -> None:
669669
if endpoint._listening_context is None:
670670
return
@@ -739,7 +739,7 @@ async def handle_client_hello_untrusted(
739739

740740

741741
async def dtls_receive_loop(
742-
endpoint_ref: ReferenceType[DTLSEndpoint], sock: _SocketType
742+
endpoint_ref: ReferenceType[DTLSEndpoint], sock: SocketType
743743
) -> None:
744744
try:
745745
while True:
@@ -829,7 +829,7 @@ class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor):
829829
830830
"""
831831

832-
def __init__(self, endpoint: DTLSEndpoint, peer_address: Address, ctx: Context):
832+
def __init__(self, endpoint: DTLSEndpoint, peer_address: Any, ctx: Context):
833833
self.endpoint = endpoint
834834
self.peer_address = peer_address
835835
self._packets_dropped_in_trio = 0
@@ -1180,7 +1180,7 @@ class DTLSEndpoint:
11801180
11811181
"""
11821182

1183-
def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10):
1183+
def __init__(self, socket: SocketType, *, incoming_packets_buffer: int = 10):
11841184
# We do this lazily on first construction, so only people who actually use DTLS
11851185
# have to install PyOpenSSL.
11861186
global SSL
@@ -1191,7 +1191,7 @@ def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10):
11911191
if socket.type != trio.socket.SOCK_DGRAM:
11921192
raise ValueError("DTLS requires a SOCK_DGRAM socket")
11931193
self._initialized = True
1194-
self.socket: _SocketType = socket
1194+
self.socket: SocketType = socket
11951195

11961196
self.incoming_packets_buffer = incoming_packets_buffer
11971197
self._token = trio.lowlevel.current_trio_token()
@@ -1200,7 +1200,7 @@ def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10):
12001200
# as a peer provides a valid cookie, we can immediately tear down the
12011201
# old connection.
12021202
# {remote address: DTLSChannel}
1203-
self._streams: WeakValueDictionary[Address, DTLSChannel] = WeakValueDictionary()
1203+
self._streams: WeakValueDictionary[Any, DTLSChannel] = WeakValueDictionary()
12041204
self._listening_context: Context | None = None
12051205
self._listening_key: bytes | None = None
12061206
self._incoming_connections_q = _Queue[DTLSChannel](float("inf"))

0 commit comments

Comments
 (0)