Skip to content

gh-129107: make bytearray free-thread safe #129108

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Feb 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Include/cpython/bytearrayobject.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ static inline char* PyByteArray_AS_STRING(PyObject *op)

/* Return the current length (ob_size) of the bytearray `op`.
 *
 * When Py_GIL_DISABLED is defined the size field is read with a relaxed
 * atomic load instead of a plain read; otherwise Py_SIZE() is used.
 * NOTE(review): presumably the atomic read is there because another thread
 * may resize the bytearray concurrently on free-threaded builds -- confirm
 * against the writers of ob_size in Objects/bytearrayobject.c.
 */
static inline Py_ssize_t PyByteArray_GET_SIZE(PyObject *op) {
    PyByteArrayObject *self = _PyByteArray_CAST(op);
    Py_ssize_t size;
#ifndef Py_GIL_DISABLED
    /* Plain read of the size field. */
    size = Py_SIZE(self);
#else
    /* Relaxed atomic read of the same field. */
    size = _Py_atomic_load_ssize_relaxed(&(_PyVarObject_CAST(self)->ob_size));
#endif
    return size;
}
/* Wrapper macro of the same name: casts its argument with _PyObject_CAST()
 * before forwarding to the inline function above. */
#define PyByteArray_GET_SIZE(self) PyByteArray_GET_SIZE(_PyObject_CAST(self))
335 changes: 335 additions & 0 deletions Lib/test/test_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,16 @@
import copy
import functools
import pickle
import sysconfig
import tempfile
import textwrap
import threading
import unittest

import test.support
from test import support
from test.support import import_helper
from test.support import threading_helper
from test.support import warnings_helper
import test.string_tests
import test.list_tests
Expand Down Expand Up @@ -2185,5 +2189,336 @@ class BytesSubclassTest(SubclassTest, unittest.TestCase):
type2test = BytesSubclass


class FreeThreadingTest(unittest.TestCase):
    # Stress test: run bytearray operations from several threads at once
    # against a single shared bytearray object.

    @unittest.skipUnless(support.Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled')
    @threading_helper.reap_threads
    @threading_helper.requires_working_threading()
    def test_free_threading_bytearray(self):
        # Test pretty much everything that can break under free-threading.
        # Non-deterministic, but at least one of these things will fail if
        # the bytearray type is not free-thread safe.
        #
        # Structure:
        #   * A set of small worker functions.  Each one takes the barrier
        #     `b` and the shared bytearray `a` (plus optional extra
        #     arguments), waits on the barrier, then performs exactly one
        #     bytearray operation.
        #   * check(), which runs a list of workers, one thread each, against
        #     the same bytearray.
        #   * A list of check() calls, almost all of which race one `clear`
        #     thread against ten threads running another operation.
        #
        # Any work a worker does *before* b.wait() (e.g. building a large
        # operand) is deliberately kept out of the raced section.
        #
        # NOTE(review): the 0xdd / 0xcd bytes tested by several workers never
        # occur in the input data used here; presumably they are the fill
        # patterns CPython's debug allocator writes into freed / freshly
        # allocated memory, so seeing one would mean a stale buffer was read.
        # Confirm against the PYTHONMALLOC debug-hook documentation.

        # --- Workers that mutate the shared bytearray -----------------------

        def clear(b, a, *args):  # MODIFIES!
            # Extra positional args are accepted and ignored so that `clear`
            # can be mixed with workers that take a third argument.
            b.wait()
            # NOTE(review): BufferError is presumably raised when another
            # thread currently has a buffer exported from `a` -- confirm.
            try: a.clear()
            except BufferError: pass

        def clear2(b, a, c):  # MODIFIES c!
            # Same as clear(), but empties the extra argument `c` instead
            # of `a`.
            b.wait()
            try: c.clear()
            except BufferError: pass

        def pop1(b, a):  # MODIFIES!
            b.wait()
            # IndexError is tolerated: the array may already be empty.
            try: a.pop()
            except IndexError: pass

        def append1(b, a):  # MODIFIES!
            b.wait()
            a.append(0)

        def insert1(b, a):  # MODIFIES!
            b.wait()
            a.insert(0, 0)

        def extend(b, a):  # MODIFIES!
            # 0x400000 bytes (4 MiB) of b'0', built before the barrier.
            c = bytearray(b'0' * 0x400000)
            b.wait()
            a.extend(c)

        def remove(b, a):  # MODIFIES!
            c = ord('0')
            b.wait()
            # ValueError is tolerated: the value may no longer be present.
            try: a.remove(c)
            except ValueError: pass

        def reverse(b, a):  # modifies inplace
            b.wait()
            a.reverse()

        # --- Workers that only read the shared bytearray --------------------
        # The first group just calls the operation and discards the result.

        def reduce(b, a):
            b.wait()
            a.__reduce__()

        def reduceex2(b, a):
            b.wait()
            a.__reduce_ex__(2)

        def reduceex3(b, a):
            b.wait()
            c = a.__reduce_ex__(3)
            # If the reduce tuple carries constructor args, the first of them
            # must not contain the 0xdd sentinel byte.
            assert not c[1] or 0xdd not in c[1][0]

        def count0(b, a):
            b.wait()
            a.count(0)

        def decode(b, a):
            b.wait()
            a.decode()

        def find(b, a):
            # 0x40000 bytes (256 KiB) needle, built before the barrier.
            c = bytearray(b'0' * 0x40000)
            b.wait()
            a.find(c)

        def hex(b, a):
            b.wait()
            a.hex('_')

        def join(b, a):
            b.wait()
            a.join([b'1', b'2', b'3'])

        def replace(b, a):
            b.wait()
            a.replace(b'0', b'')

        def maketrans(b, a, c):
            b.wait()
            # NOTE(review): ValueError presumably occurs when `a` and `c`
            # have different lengths because one was cleared -- confirm.
            try: a.maketrans(a, c)
            except ValueError: pass

        def translate(b, a, c):
            b.wait()
            a.translate(c)

        # The following readers also assert on the result.  Each assertion is
        # written so that it holds both for the original contents and for an
        # already-cleared (empty) array.

        def copy(b, a):
            b.wait()
            c = a.copy()
            if c: assert c[0] == 48  # '0'

        def endswith(b, a):
            b.wait()
            assert not a.endswith(b'\xdd')

        def index(b, a):
            b.wait()
            # 0xdd must never be found; reaching `assert False` means
            # index() returned instead of raising.
            try: a.index(b'\xdd')
            except ValueError: return
            assert False

        def lstrip(b, a):
            b.wait()
            assert not a.lstrip(b'0')

        def partition(b, a):
            b.wait()
            assert not a.partition(b'\xdd')[2]

        def removeprefix(b, a):
            b.wait()
            assert not a.removeprefix(b'0')

        def removesuffix(b, a):
            b.wait()
            assert not a.removesuffix(b'0')

        def rfind(b, a):
            b.wait()
            assert a.rfind(b'\xdd') == -1

        def rindex(b, a):
            b.wait()
            # Mirror of index(): a successful rindex() is a failure.
            try: a.rindex(b'\xdd')
            except ValueError: return
            assert False

        def rpartition(b, a):
            b.wait()
            assert not a.rpartition(b'\xdd')[0]

        def rsplit(b, a):
            b.wait()
            assert len(a.rsplit(b'\xdd')) == 1

        def rstrip(b, a):
            b.wait()
            assert not a.rstrip(b'0')

        def split(b, a):
            b.wait()
            assert len(a.split(b'\xdd')) == 1

        def splitlines(b, a):
            b.wait()
            l = len(a.splitlines())
            # Either many lines (original contents) or none (cleared);
            # exactly one line is treated as a failure.
            assert l > 1 or l == 0

        def startswith(b, a):
            b.wait()
            assert not a.startswith(b'\xdd')

        def strip(b, a):
            b.wait()
            assert not a.strip(b'0')

        def repeat(b, a):
            b.wait()
            a * 2

        def contains(b, a):
            b.wait()
            assert 0xdd not in a

        def iconcat(b, a):  # MODIFIES!
            # 0x400000 bytes (4 MiB) operand, built before the barrier.
            c = bytearray(b'0' * 0x400000)
            b.wait()
            a += c

        def irepeat(b, a):  # MODIFIES!
            b.wait()
            a *= 2

        def subscript(b, a):
            b.wait()
            # IndexError is tolerated: the array may already be empty.
            try: assert a[0] != 0xdd
            except IndexError: pass

        def ass_subscript(b, a):  # MODIFIES!
            # Whole-slice assignment with a 4 MiB replacement.
            c = bytearray(b'0' * 0x400000)
            b.wait()
            a[:] = c

        def mod(b, a):
            c = tuple(range(4096))
            b.wait()
            # NOTE(review): TypeError presumably occurs when the format
            # string was cleared and the arguments are left unconsumed.
            try: a % c
            except TypeError: pass

        def repr_(b, a):
            b.wait()
            repr(a)

        # Workers for the methods that build a new bytearray from `a`.  Each
        # checks one byte of the result against the 0xdd / 0xcd sentinels
        # (see the NOTE at the top of this method).

        def capitalize(b, a):
            b.wait()
            c = a.capitalize()
            assert not c or c[0] not in (0xdd, 0xcd)

        def center(b, a):
            b.wait()
            c = a.center(0x60000)
            assert not c or c[0x20000] not in (0xdd, 0xcd)

        def expandtabs(b, a):
            b.wait()
            c = a.expandtabs()
            assert not c or c[0] not in (0xdd, 0xcd)

        def ljust(b, a):
            b.wait()
            c = a.ljust(0x600000)
            assert not c or c[0] not in (0xdd, 0xcd)

        def lower(b, a):
            b.wait()
            c = a.lower()
            assert not c or c[0] not in (0xdd, 0xcd)

        def rjust(b, a):
            b.wait()
            c = a.rjust(0x600000)
            assert not c or c[-1] not in (0xdd, 0xcd)

        def swapcase(b, a):
            b.wait()
            c = a.swapcase()
            assert not c or c[-1] not in (0xdd, 0xcd)

        def title(b, a):
            b.wait()
            c = a.title()
            assert not c or c[-1] not in (0xdd, 0xcd)

        def upper(b, a):
            b.wait()
            c = a.upper()
            assert not c or c[-1] not in (0xdd, 0xcd)

        def zfill(b, a):
            b.wait()
            c = a.zfill(0x400000)
            assert not c or c[-1] not in (0xdd, 0xcd)

        def check(funcs, a=None, *args):
            # Run every function in `funcs` in its own thread against the
            # same bytearray.
            #
            # funcs: worker callables; each is called as
            #        func(barrier, a, *args).
            # a:     the shared bytearray; defaults to 0x400000 (4 MiB)
            #        bytes of b'0'.
            # args:  extra positional arguments forwarded to every worker.
            if a is None:
                a = bytearray(b'0' * 0x400000)

            # One barrier party per worker, so no worker passes b.wait()
            # until all of them have reached it.
            barrier = threading.Barrier(len(funcs))
            threads = []

            for func in funcs:
                thread = threading.Thread(target=func, args=(barrier, a, *args))

                threads.append(thread)

            # The threads are started by the context manager; the body is
            # intentionally empty.
            with threading_helper.start_threads(threads):
                pass

            # NOTE(review): start_threads() presumably joins the threads on
            # exit already, which would make this loop redundant -- confirm
            # against test.support.threading_helper before removing.
            for thread in threads:
                threading_helper.join_thread(thread)

        # hard errors
        # NOTE(review): presumably operations whose failure mode is a crash
        # or memory corruption rather than a wrong result -- none of these
        # workers asserts on a return value.

        check([clear] + [reduce] * 10)
        check([clear] + [reduceex2] * 10)
        check([clear] + [append1] * 10)
        check([clear] * 10)
        check([clear] + [count0] * 10)
        check([clear] + [decode] * 10)
        check([clear] + [extend] * 10)
        check([clear] + [find] * 10)
        check([clear] + [hex] * 10)
        check([clear] + [insert1] * 10)
        check([clear] + [join] * 10)
        check([clear] + [pop1] * 10)
        check([clear] + [remove] * 10)
        check([clear] + [replace] * 10)
        check([clear] + [reverse] * 10)
        # maketrans reads two bytearrays, so both are cleared concurrently:
        # `clear` empties `a`, `clear2` empties the extra argument.
        check([clear, clear2] + [maketrans] * 10, bytearray(range(128)), bytearray(range(128)))
        check([clear] + [translate] * 10, None, bytearray.maketrans(bytearray(range(128)), bytearray(range(128))))

        check([clear] + [repeat] * 10)
        check([clear] + [iconcat] * 10)
        check([clear] + [irepeat] * 10)
        check([clear] + [ass_subscript] * 10)
        check([clear] + [repr_] * 10)

        # value errors
        # These workers assert on the value returned by the operation.

        check([clear] + [reduceex3] * 10, bytearray(b'a' * 0x40000))
        check([clear] + [copy] * 10)
        check([clear] + [endswith] * 10)
        check([clear] + [index] * 10)
        check([clear] + [lstrip] * 10)
        check([clear] + [partition] * 10)
        check([clear] + [removeprefix] * 10, bytearray(b'0'))
        check([clear] + [removesuffix] * 10, bytearray(b'0'))
        check([clear] + [rfind] * 10)
        check([clear] + [rindex] * 10)
        check([clear] + [rpartition] * 10)
        check([clear] + [rsplit] * 10, bytearray(b'0' * 0x4000))
        check([clear] + [rstrip] * 10)
        check([clear] + [split] * 10, bytearray(b'0' * 0x4000))
        check([clear] + [splitlines] * 10, bytearray(b'\n' * 0x400))
        check([clear] + [startswith] * 10)
        check([clear] + [strip] * 10)

        check([clear] + [contains] * 10)
        check([clear] + [subscript] * 10)
        check([clear] + [mod] * 10, bytearray(b'%d' * 4096))

        # Methods that return a transformed copy of the array.
        check([clear] + [capitalize] * 10, bytearray(b'a' * 0x40000))
        check([clear] + [center] * 10, bytearray(b'a' * 0x40000))
        check([clear] + [expandtabs] * 10, bytearray(b'0\t' * 4096))
        check([clear] + [ljust] * 10, bytearray(b'0' * 0x400000))
        check([clear] + [lower] * 10, bytearray(b'A' * 0x400000))
        check([clear] + [rjust] * 10, bytearray(b'0' * 0x400000))
        check([clear] + [swapcase] * 10, bytearray(b'aA' * 0x200000))
        check([clear] + [title] * 10, bytearray(b'aA' * 0x200000))
        check([clear] + [upper] * 10, bytearray(b'a' * 0x400000))
        check([clear] + [zfill] * 10, bytearray(b'1' * 0x200000))


# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make :class:`bytearray` safe under :term:`free threading`.
Loading
Loading