Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ time.

## Bug Fixes

* [`Broadcast`] receivers now get cleaned up once they go out of scope.

* [`Timer`] now returns [timezone-aware] `datetime` objects using UTC as
timezone.


[`Broadcast`]: https://frequenz-floss.github.io/frequenz-channels-python/v0.11/reference/frequenz/channels/#frequenz.channels.Broadcast
[`Timer`]: https://frequenz-floss.github.io/frequenz-channels-python/v0.11/reference/frequenz/channels/#frequenz.channels.Timer
[timezone-aware]: https://docs.python.org/3/library/datetime.html#aware-and-naive-objects
31 changes: 11 additions & 20 deletions src/frequenz/channels/broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from __future__ import annotations

import logging
import weakref
from asyncio import Condition
from collections import deque
from typing import Deque, Dict, Generic, Optional
Expand Down Expand Up @@ -81,7 +82,7 @@ def __init__(self, name: str, resend_latest: bool = False) -> None:
self._resend_latest = resend_latest

self.recv_cv: Condition = Condition()
self.receivers: Dict[UUID, Receiver[T]] = {}
self.receivers: Dict[UUID, weakref.ReferenceType[Receiver[T]]] = {}
self.closed: bool = False
self._latest: Optional[T] = None

Expand All @@ -101,17 +102,6 @@ async def close(self) -> None:
async with self.recv_cv:
self.recv_cv.notify_all()

def _drop_receiver(self, uuid: UUID) -> None:
"""Drop a specific receiver from the list of broadcast receivers.

Called from the destructors of receivers.

Args:
uuid: a uuid identifying the receiver to be dropped.
"""
if uuid in self.receivers:
del self.receivers[uuid]

def get_sender(self) -> Sender[T]:
"""Create a new broadcast sender.

Expand Down Expand Up @@ -140,7 +130,7 @@ def get_receiver(
if name is None:
name = str(uuid)
recv: Receiver[T] = Receiver(uuid, name, maxsize, self)
self.receivers[uuid] = recv
self.receivers[uuid] = weakref.ref(recv)
if self._resend_latest and self._latest is not None:
recv.enqueue(self._latest)
return recv
Expand Down Expand Up @@ -188,8 +178,15 @@ async def send(self, msg: T) -> bool:
return False
# pylint: disable=protected-access
self._chan._latest = msg
for recv in self._chan.receivers.values():
stale_refs = []
for name, recv_ref in self._chan.receivers.items():
recv = recv_ref()
if recv is None:
stale_refs.append(name)
continue
recv.enqueue(msg)
for name in stale_refs:
del self._chan.receivers[name]
async with self._chan.recv_cv:
self._chan.recv_cv.notify_all()
return True
Expand Down Expand Up @@ -225,11 +222,6 @@ def __init__(self, uuid: UUID, name: str, maxsize: int, chan: Broadcast[T]) -> N

self._active = True

def __del__(self) -> None:
"""Drop this receiver from the list of Broadcast receivers."""
if self._active:
self._chan._drop_receiver(self._uuid)

def enqueue(self, msg: T) -> None:
"""Put a message into this receiver's queue.

Expand Down Expand Up @@ -295,7 +287,6 @@ def into_peekable(self) -> Peekable[T]:
Returns:
A `Peekable` instance.
"""
self._chan._drop_receiver(self._uuid) # pylint: disable=protected-access
self._active = False
return Peekable(self._chan)

Expand Down
22 changes: 22 additions & 0 deletions tests/test_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,25 @@ async def test_broadcast_map() -> None:

assert (await receiver.receive()) is False
assert (await receiver.receive()) is True


async def test_broadcast_receiver_drop() -> None:
"""Ensure deleted receivers get cleaned up."""
chan = Broadcast[int]("input-chan")
sender = chan.get_sender()

receiver1 = chan.get_receiver()
receiver2 = chan.get_receiver()

await sender.send(10)

assert 10 == await receiver1.receive()
assert 10 == await receiver2.receive()

assert len(chan.receivers) == 2

del receiver2

await sender.send(20)

assert len(chan.receivers) == 1