Skip to content

gh-123089: avoid _IterationGuard for WeakSet, make it thread safe #123279

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 10 additions & 43 deletions Lib/_weakrefset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,41 +36,26 @@ def __exit__(self, e, t, b):
class WeakSet:
def __init__(self, data=None):
self.data = set()

def _remove(item, selfref=ref(self)):
self = selfref()
if self is not None:
if self._iterating:
self._pending_removals.append(item)
else:
self.data.discard(item)
self.data.discard(item)

self._remove = _remove
# A list of keys to be removed
self._pending_removals = []
self._iterating = set()
if data is not None:
self.update(data)

def _commit_removals(self):
pop = self._pending_removals.pop
discard = self.data.discard
while True:
try:
item = pop()
except IndexError:
return
discard(item)

def __iter__(self):
with _IterationGuard(self):
for itemref in self.data:
item = itemref()
if item is not None:
# Caveat: the iterator will keep a strong reference to
# `item` until it is resumed or closed.
yield item
for itemref in self.data.copy():
item = itemref()
if item is not None:
# Caveat: the iterator will keep a strong reference to
# `item` until it is resumed or closed.
yield item

def __len__(self):
return len(self.data) - len(self._pending_removals)
return len(self.data)

def __contains__(self, item):
try:
Expand All @@ -83,21 +68,15 @@ def __reduce__(self):
return self.__class__, (list(self),), self.__getstate__()

def add(self, item):
if self._pending_removals:
self._commit_removals()
self.data.add(ref(item, self._remove))

def clear(self):
if self._pending_removals:
self._commit_removals()
self.data.clear()

def copy(self):
return self.__class__(self)

def pop(self):
if self._pending_removals:
self._commit_removals()
while True:
try:
itemref = self.data.pop()
Expand All @@ -108,18 +87,12 @@ def pop(self):
return item

def remove(self, item):
if self._pending_removals:
self._commit_removals()
self.data.remove(ref(item))

def discard(self, item):
if self._pending_removals:
self._commit_removals()
self.data.discard(ref(item))

def update(self, other):
if self._pending_removals:
self._commit_removals()
for element in other:
self.add(element)

Expand All @@ -136,8 +109,6 @@ def difference(self, other):
def difference_update(self, other):
self.__isub__(other)
def __isub__(self, other):
if self._pending_removals:
self._commit_removals()
if self is other:
self.data.clear()
else:
Expand All @@ -151,8 +122,6 @@ def intersection(self, other):
def intersection_update(self, other):
self.__iand__(other)
def __iand__(self, other):
if self._pending_removals:
self._commit_removals()
self.data.intersection_update(ref(item) for item in other)
return self

Expand Down Expand Up @@ -184,8 +153,6 @@ def symmetric_difference(self, other):
def symmetric_difference_update(self, other):
self.__ixor__(other)
def __ixor__(self, other):
if self._pending_removals:
self._commit_removals()
if self is other:
self.data.clear()
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make :class:`weakref.WeakSet` safe against concurrent mutations while it is being iterated. Patch by Kumar Aditya.
Loading