From d32fcba1c73cf385277dc13aaf26080916aa24de Mon Sep 17 00:00:00 2001 From: Tobias Petersen Date: Wed, 4 Sep 2024 13:14:20 +0200 Subject: [PATCH] Accept abstract namespace paths for unix domain sockets Accept paths starting with a null byte in create_unix_listener and connect_unix_socket to allow creating abstract unix sockets. Fixes #781 --- src/anyio/_core/_sockets.py | 8 +++++-- tests/test_sockets.py | 44 +++++++++++++++++++++++++++++-------- 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/src/anyio/_core/_sockets.py b/src/anyio/_core/_sockets.py index 647597b6..388a869b 100644 --- a/src/anyio/_core/_sockets.py +++ b/src/anyio/_core/_sockets.py @@ -671,8 +671,12 @@ async def setup_unix_local_socket( if path is not None: path_str = str(path) path = Path(path) - if path.is_socket(): - path.unlink() + if path_str.startswith("\0"): + # Unix abstract namespace socket. No file backing so skip stat call + pass + else: + if path.is_socket(): + path.unlink() else: path_str = None diff --git a/tests/test_sockets.py b/tests/test_sockets.py index acffe920..cfbc496d 100644 --- a/tests/test_sockets.py +++ b/tests/test_sockets.py @@ -695,9 +695,16 @@ async def handle(stream: SocketStream) -> None: sys.platform == "win32", reason="UNIX sockets are not available on Windows" ) class TestUNIXStream: - @pytest.fixture - def socket_path(self, tmp_path_factory: TempPathFactory) -> Path: - return tmp_path_factory.mktemp("unix").joinpath("socket") + @pytest.fixture(params=["path", "abstract"]) + def socket_path( + self, request: SubRequest, tmp_path_factory: TempPathFactory + ) -> Path: + path = tmp_path_factory.mktemp("unix").joinpath("socket") + + if request.param == "path": + return path + elif request.param == "abstract": + return Path(f"\0{path}") @pytest.fixture(params=[False, True], ids=["str", "path"]) def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str: @@ -721,7 +728,15 @@ async def test_extra_attributes( assert ( stream.extra(SocketAttribute.local_address) == raw_socket.getsockname() ) - assert stream.extra(SocketAttribute.remote_address) == str(socket_path) + remote_addr = stream.extra(SocketAttribute.remote_address) + if isinstance(remote_addr, str): + assert stream.extra(SocketAttribute.remote_address) == str(socket_path) + else: + assert isinstance(remote_addr, bytes) + assert stream.extra(SocketAttribute.remote_address) == bytes( + socket_path + ) + pytest.raises( TypedAttributeLookupError, stream.extra, SocketAttribute.local_port ) @@ -960,17 +975,28 @@ async def test_send_after_close( await stream.send(b"foo") async def test_cannot_connect(self, socket_path: Path) -> None: - with pytest.raises(FileNotFoundError): - await connect_unix(socket_path) + if str(socket_path).startswith("\0"): + with pytest.raises(ConnectionRefusedError): + await connect_unix(socket_path) + else: + with pytest.raises(FileNotFoundError): + await connect_unix(socket_path) @pytest.mark.skipif( sys.platform == "win32", reason="UNIX sockets are not available on Windows" ) class TestUNIXListener: - @pytest.fixture - def socket_path(self, tmp_path_factory: TempPathFactory) -> Path: - return tmp_path_factory.mktemp("unix").joinpath("socket") + @pytest.fixture(params=["path", "abstract"]) + def socket_path( + self, request: SubRequest, tmp_path_factory: TempPathFactory + ) -> Path: + path = tmp_path_factory.mktemp("unix").joinpath("socket") + + if request.param == "path": + return path + elif request.param == "abstract": + return Path(f"\0{path}") @pytest.fixture(params=[False, True], ids=["str", "path"]) def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str: