Skip to content

Commit

Permalink
Accept abstract namespace paths for unix domain sockets
Browse files Browse the repository at this point in the history
Accept paths starting with a null byte in create_unix_listener and
connect_unix_socket to allow creating abstract unix sockets. Fixes agronholm#781
  • Loading branch information
tapetersen committed Sep 4, 2024
1 parent 7ef43a6 commit d32fcba
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 11 deletions.
8 changes: 6 additions & 2 deletions src/anyio/_core/_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,8 +671,12 @@ async def setup_unix_local_socket(
if path is not None:
path_str = str(path)
path = Path(path)
if path.is_socket():
path.unlink()
if path_str.startswith("\0"):
# Unix abstract namespace socket. No file backing so skip stat call
pass
else:
if path.is_socket():
path.unlink()
else:
path_str = None

Expand Down
44 changes: 35 additions & 9 deletions tests/test_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,9 +695,16 @@ async def handle(stream: SocketStream) -> None:
sys.platform == "win32", reason="UNIX sockets are not available on Windows"
)
class TestUNIXStream:
@pytest.fixture
def socket_path(self, tmp_path_factory: TempPathFactory) -> Path:
return tmp_path_factory.mktemp("unix").joinpath("socket")
@pytest.fixture(params=["path", "abstract"])
def socket_path(
self, request: SubRequest, tmp_path_factory: TempPathFactory
) -> Path:
path = tmp_path_factory.mktemp("unix").joinpath("socket")

if request.param == "path":
return path
elif request.param == "abstract":
return Path(f"\0{path}")

@pytest.fixture(params=[False, True], ids=["str", "path"])
def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str:
Expand All @@ -721,7 +728,15 @@ async def test_extra_attributes(
assert (
stream.extra(SocketAttribute.local_address) == raw_socket.getsockname()
)
assert stream.extra(SocketAttribute.remote_address) == str(socket_path)
remote_addr = stream.extra(SocketAttribute.remote_address)
if isinstance(remote_addr, str):
assert stream.extra(SocketAttribute.remote_address) == str(socket_path)
else:
assert isinstance(remote_addr, bytes)
assert stream.extra(SocketAttribute.remote_address) == bytes(
socket_path
)

pytest.raises(
TypedAttributeLookupError, stream.extra, SocketAttribute.local_port
)
Expand Down Expand Up @@ -960,17 +975,28 @@ async def test_send_after_close(
await stream.send(b"foo")

async def test_cannot_connect(self, socket_path: Path) -> None:
with pytest.raises(FileNotFoundError):
await connect_unix(socket_path)
if str(socket_path).startswith("\0"):
with pytest.raises(ConnectionRefusedError):
await connect_unix(socket_path)
else:
with pytest.raises(FileNotFoundError):
await connect_unix(socket_path)


@pytest.mark.skipif(
sys.platform == "win32", reason="UNIX sockets are not available on Windows"
)
class TestUNIXListener:
@pytest.fixture
def socket_path(self, tmp_path_factory: TempPathFactory) -> Path:
return tmp_path_factory.mktemp("unix").joinpath("socket")
@pytest.fixture(params=["path", "abstract"])
def socket_path(
self, request: SubRequest, tmp_path_factory: TempPathFactory
) -> Path:
path = tmp_path_factory.mktemp("unix").joinpath("socket")

if request.param == "path":
return path
elif request.param == "abstract":
return Path(f"\0{path}")

@pytest.fixture(params=[False, True], ids=["str", "path"])
def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str:
Expand Down

0 comments on commit d32fcba

Please sign in to comment.