Skip to content

Commit a05433f

Browse files
gh-129107: make bytearray thread safe (#129108)
Co-authored-by: Kumar Aditya <kumaraditya@python.org>
1 parent d2e60d8 commit a05433f

File tree

5 files changed

+904
-100
lines changed

5 files changed

+904
-100
lines changed

Include/cpython/bytearrayobject.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ static inline char* PyByteArray_AS_STRING(PyObject *op)
2929

3030
static inline Py_ssize_t PyByteArray_GET_SIZE(PyObject *op) {
3131
PyByteArrayObject *self = _PyByteArray_CAST(op);
32+
#ifdef Py_GIL_DISABLED
33+
return _Py_atomic_load_ssize_relaxed(&(_PyVarObject_CAST(self)->ob_size));
34+
#else
3235
return Py_SIZE(self);
36+
#endif
3337
}
3438
#define PyByteArray_GET_SIZE(self) PyByteArray_GET_SIZE(_PyObject_CAST(self))

Lib/test/test_bytes.py

Lines changed: 335 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,16 @@
1111
import copy
1212
import functools
1313
import pickle
14+
import sysconfig
1415
import tempfile
1516
import textwrap
17+
import threading
1618
import unittest
1719

1820
import test.support
21+
from test import support
1922
from test.support import import_helper
23+
from test.support import threading_helper
2024
from test.support import warnings_helper
2125
import test.string_tests
2226
import test.list_tests
@@ -2185,5 +2189,336 @@ class BytesSubclassTest(SubclassTest, unittest.TestCase):
21852189
type2test = BytesSubclass
21862190

21872191

2192+
class FreeThreadingTest(unittest.TestCase):
2193+
@unittest.skipUnless(support.Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled')
2194+
@threading_helper.reap_threads
2195+
@threading_helper.requires_working_threading()
2196+
def test_free_threading_bytearray(self):
2197+
# Test pretty much everything that can break under free-threading.
2198+
# Non-deterministic, but at least one of these things will fail if
2199+
# bytearray module is not free-thread safe.
2200+
2201+
def clear(b, a, *args): # MODIFIES!
2202+
b.wait()
2203+
try: a.clear()
2204+
except BufferError: pass
2205+
2206+
def clear2(b, a, c): # MODIFIES c!
2207+
b.wait()
2208+
try: c.clear()
2209+
except BufferError: pass
2210+
2211+
def pop1(b, a): # MODIFIES!
2212+
b.wait()
2213+
try: a.pop()
2214+
except IndexError: pass
2215+
2216+
def append1(b, a): # MODIFIES!
2217+
b.wait()
2218+
a.append(0)
2219+
2220+
def insert1(b, a): # MODIFIES!
2221+
b.wait()
2222+
a.insert(0, 0)
2223+
2224+
def extend(b, a): # MODIFIES!
2225+
c = bytearray(b'0' * 0x400000)
2226+
b.wait()
2227+
a.extend(c)
2228+
2229+
def remove(b, a): # MODIFIES!
2230+
c = ord('0')
2231+
b.wait()
2232+
try: a.remove(c)
2233+
except ValueError: pass
2234+
2235+
def reverse(b, a): # modifies inplace
2236+
b.wait()
2237+
a.reverse()
2238+
2239+
def reduce(b, a):
2240+
b.wait()
2241+
a.__reduce__()
2242+
2243+
def reduceex2(b, a):
2244+
b.wait()
2245+
a.__reduce_ex__(2)
2246+
2247+
def reduceex3(b, a):
2248+
b.wait()
2249+
c = a.__reduce_ex__(3)
2250+
assert not c[1] or 0xdd not in c[1][0]
2251+
2252+
def count0(b, a):
2253+
b.wait()
2254+
a.count(0)
2255+
2256+
def decode(b, a):
2257+
b.wait()
2258+
a.decode()
2259+
2260+
def find(b, a):
2261+
c = bytearray(b'0' * 0x40000)
2262+
b.wait()
2263+
a.find(c)
2264+
2265+
def hex(b, a):
2266+
b.wait()
2267+
a.hex('_')
2268+
2269+
def join(b, a):
2270+
b.wait()
2271+
a.join([b'1', b'2', b'3'])
2272+
2273+
def replace(b, a):
2274+
b.wait()
2275+
a.replace(b'0', b'')
2276+
2277+
def maketrans(b, a, c):
2278+
b.wait()
2279+
try: a.maketrans(a, c)
2280+
except ValueError: pass
2281+
2282+
def translate(b, a, c):
2283+
b.wait()
2284+
a.translate(c)
2285+
2286+
def copy(b, a):
2287+
b.wait()
2288+
c = a.copy()
2289+
if c: assert c[0] == 48 # '0'
2290+
2291+
def endswith(b, a):
2292+
b.wait()
2293+
assert not a.endswith(b'\xdd')
2294+
2295+
def index(b, a):
2296+
b.wait()
2297+
try: a.index(b'\xdd')
2298+
except ValueError: return
2299+
assert False
2300+
2301+
def lstrip(b, a):
2302+
b.wait()
2303+
assert not a.lstrip(b'0')
2304+
2305+
def partition(b, a):
2306+
b.wait()
2307+
assert not a.partition(b'\xdd')[2]
2308+
2309+
def removeprefix(b, a):
2310+
b.wait()
2311+
assert not a.removeprefix(b'0')
2312+
2313+
def removesuffix(b, a):
2314+
b.wait()
2315+
assert not a.removesuffix(b'0')
2316+
2317+
def rfind(b, a):
2318+
b.wait()
2319+
assert a.rfind(b'\xdd') == -1
2320+
2321+
def rindex(b, a):
2322+
b.wait()
2323+
try: a.rindex(b'\xdd')
2324+
except ValueError: return
2325+
assert False
2326+
2327+
def rpartition(b, a):
2328+
b.wait()
2329+
assert not a.rpartition(b'\xdd')[0]
2330+
2331+
def rsplit(b, a):
2332+
b.wait()
2333+
assert len(a.rsplit(b'\xdd')) == 1
2334+
2335+
def rstrip(b, a):
2336+
b.wait()
2337+
assert not a.rstrip(b'0')
2338+
2339+
def split(b, a):
2340+
b.wait()
2341+
assert len(a.split(b'\xdd')) == 1
2342+
2343+
def splitlines(b, a):
2344+
b.wait()
2345+
l = len(a.splitlines())
2346+
assert l > 1 or l == 0
2347+
2348+
def startswith(b, a):
2349+
b.wait()
2350+
assert not a.startswith(b'\xdd')
2351+
2352+
def strip(b, a):
2353+
b.wait()
2354+
assert not a.strip(b'0')
2355+
2356+
def repeat(b, a):
2357+
b.wait()
2358+
a * 2
2359+
2360+
def contains(b, a):
2361+
b.wait()
2362+
assert 0xdd not in a
2363+
2364+
def iconcat(b, a): # MODIFIES!
2365+
c = bytearray(b'0' * 0x400000)
2366+
b.wait()
2367+
a += c
2368+
2369+
def irepeat(b, a): # MODIFIES!
2370+
b.wait()
2371+
a *= 2
2372+
2373+
def subscript(b, a):
2374+
b.wait()
2375+
try: assert a[0] != 0xdd
2376+
except IndexError: pass
2377+
2378+
def ass_subscript(b, a): # MODIFIES!
2379+
c = bytearray(b'0' * 0x400000)
2380+
b.wait()
2381+
a[:] = c
2382+
2383+
def mod(b, a):
2384+
c = tuple(range(4096))
2385+
b.wait()
2386+
try: a % c
2387+
except TypeError: pass
2388+
2389+
def repr_(b, a):
2390+
b.wait()
2391+
repr(a)
2392+
2393+
def capitalize(b, a):
2394+
b.wait()
2395+
c = a.capitalize()
2396+
assert not c or c[0] not in (0xdd, 0xcd)
2397+
2398+
def center(b, a):
2399+
b.wait()
2400+
c = a.center(0x60000)
2401+
assert not c or c[0x20000] not in (0xdd, 0xcd)
2402+
2403+
def expandtabs(b, a):
2404+
b.wait()
2405+
c = a.expandtabs()
2406+
assert not c or c[0] not in (0xdd, 0xcd)
2407+
2408+
def ljust(b, a):
2409+
b.wait()
2410+
c = a.ljust(0x600000)
2411+
assert not c or c[0] not in (0xdd, 0xcd)
2412+
2413+
def lower(b, a):
2414+
b.wait()
2415+
c = a.lower()
2416+
assert not c or c[0] not in (0xdd, 0xcd)
2417+
2418+
def rjust(b, a):
2419+
b.wait()
2420+
c = a.rjust(0x600000)
2421+
assert not c or c[-1] not in (0xdd, 0xcd)
2422+
2423+
def swapcase(b, a):
2424+
b.wait()
2425+
c = a.swapcase()
2426+
assert not c or c[-1] not in (0xdd, 0xcd)
2427+
2428+
def title(b, a):
2429+
b.wait()
2430+
c = a.title()
2431+
assert not c or c[-1] not in (0xdd, 0xcd)
2432+
2433+
def upper(b, a):
2434+
b.wait()
2435+
c = a.upper()
2436+
assert not c or c[-1] not in (0xdd, 0xcd)
2437+
2438+
def zfill(b, a):
2439+
b.wait()
2440+
c = a.zfill(0x400000)
2441+
assert not c or c[-1] not in (0xdd, 0xcd)
2442+
2443+
def check(funcs, a=None, *args):
2444+
if a is None:
2445+
a = bytearray(b'0' * 0x400000)
2446+
2447+
barrier = threading.Barrier(len(funcs))
2448+
threads = []
2449+
2450+
for func in funcs:
2451+
thread = threading.Thread(target=func, args=(barrier, a, *args))
2452+
2453+
threads.append(thread)
2454+
2455+
with threading_helper.start_threads(threads):
2456+
pass
2457+
2458+
for thread in threads:
2459+
threading_helper.join_thread(thread)
2460+
2461+
# hard errors
2462+
2463+
check([clear] + [reduce] * 10)
2464+
check([clear] + [reduceex2] * 10)
2465+
check([clear] + [append1] * 10)
2466+
check([clear] * 10)
2467+
check([clear] + [count0] * 10)
2468+
check([clear] + [decode] * 10)
2469+
check([clear] + [extend] * 10)
2470+
check([clear] + [find] * 10)
2471+
check([clear] + [hex] * 10)
2472+
check([clear] + [insert1] * 10)
2473+
check([clear] + [join] * 10)
2474+
check([clear] + [pop1] * 10)
2475+
check([clear] + [remove] * 10)
2476+
check([clear] + [replace] * 10)
2477+
check([clear] + [reverse] * 10)
2478+
check([clear, clear2] + [maketrans] * 10, bytearray(range(128)), bytearray(range(128)))
2479+
check([clear] + [translate] * 10, None, bytearray.maketrans(bytearray(range(128)), bytearray(range(128))))
2480+
2481+
check([clear] + [repeat] * 10)
2482+
check([clear] + [iconcat] * 10)
2483+
check([clear] + [irepeat] * 10)
2484+
check([clear] + [ass_subscript] * 10)
2485+
check([clear] + [repr_] * 10)
2486+
2487+
# value errors
2488+
2489+
check([clear] + [reduceex3] * 10, bytearray(b'a' * 0x40000))
2490+
check([clear] + [copy] * 10)
2491+
check([clear] + [endswith] * 10)
2492+
check([clear] + [index] * 10)
2493+
check([clear] + [lstrip] * 10)
2494+
check([clear] + [partition] * 10)
2495+
check([clear] + [removeprefix] * 10, bytearray(b'0'))
2496+
check([clear] + [removesuffix] * 10, bytearray(b'0'))
2497+
check([clear] + [rfind] * 10)
2498+
check([clear] + [rindex] * 10)
2499+
check([clear] + [rpartition] * 10)
2500+
check([clear] + [rsplit] * 10, bytearray(b'0' * 0x4000))
2501+
check([clear] + [rstrip] * 10)
2502+
check([clear] + [split] * 10, bytearray(b'0' * 0x4000))
2503+
check([clear] + [splitlines] * 10, bytearray(b'\n' * 0x400))
2504+
check([clear] + [startswith] * 10)
2505+
check([clear] + [strip] * 10)
2506+
2507+
check([clear] + [contains] * 10)
2508+
check([clear] + [subscript] * 10)
2509+
check([clear] + [mod] * 10, bytearray(b'%d' * 4096))
2510+
2511+
check([clear] + [capitalize] * 10, bytearray(b'a' * 0x40000))
2512+
check([clear] + [center] * 10, bytearray(b'a' * 0x40000))
2513+
check([clear] + [expandtabs] * 10, bytearray(b'0\t' * 4096))
2514+
check([clear] + [ljust] * 10, bytearray(b'0' * 0x400000))
2515+
check([clear] + [lower] * 10, bytearray(b'A' * 0x400000))
2516+
check([clear] + [rjust] * 10, bytearray(b'0' * 0x400000))
2517+
check([clear] + [swapcase] * 10, bytearray(b'aA' * 0x200000))
2518+
check([clear] + [title] * 10, bytearray(b'aA' * 0x200000))
2519+
check([clear] + [upper] * 10, bytearray(b'a' * 0x400000))
2520+
check([clear] + [zfill] * 10, bytearray(b'1' * 0x200000))
2521+
2522+
21882523
if __name__ == "__main__":
21892524
unittest.main()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make the :type:`bytearray` safe under :term:`free threading`.

0 commit comments

Comments
 (0)