|
11 | 11 | import copy
|
12 | 12 | import functools
|
13 | 13 | import pickle
|
| 14 | +import sysconfig |
14 | 15 | import tempfile
|
15 | 16 | import textwrap
|
| 17 | +import threading |
16 | 18 | import unittest
|
17 | 19 |
|
18 | 20 | import test.support
|
| 21 | +from test import support |
19 | 22 | from test.support import import_helper
|
| 23 | +from test.support import threading_helper |
20 | 24 | from test.support import warnings_helper
|
21 | 25 | import test.string_tests
|
22 | 26 | import test.list_tests
|
@@ -2185,5 +2189,336 @@ class BytesSubclassTest(SubclassTest, unittest.TestCase):
|
2185 | 2189 | type2test = BytesSubclass
|
2186 | 2190 |
|
2187 | 2191 |
|
| 2192 | +class FreeThreadingTest(unittest.TestCase): |
| 2193 | + @unittest.skipUnless(support.Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled') |
| 2194 | + @threading_helper.reap_threads |
| 2195 | + @threading_helper.requires_working_threading() |
| 2196 | + def test_free_threading_bytearray(self): |
| 2197 | + # Test pretty much everything that can break under free-threading. |
| 2198 | + # Non-deterministic, but at least one of these things will fail if |
| 2199 | + # bytearray module is not free-thread safe. |
| 2200 | + |
| 2201 | + def clear(b, a, *args): # MODIFIES! |
| 2202 | + b.wait() |
| 2203 | + try: a.clear() |
| 2204 | + except BufferError: pass |
| 2205 | + |
| 2206 | + def clear2(b, a, c): # MODIFIES c! |
| 2207 | + b.wait() |
| 2208 | + try: c.clear() |
| 2209 | + except BufferError: pass |
| 2210 | + |
| 2211 | + def pop1(b, a): # MODIFIES! |
| 2212 | + b.wait() |
| 2213 | + try: a.pop() |
| 2214 | + except IndexError: pass |
| 2215 | + |
| 2216 | + def append1(b, a): # MODIFIES! |
| 2217 | + b.wait() |
| 2218 | + a.append(0) |
| 2219 | + |
| 2220 | + def insert1(b, a): # MODIFIES! |
| 2221 | + b.wait() |
| 2222 | + a.insert(0, 0) |
| 2223 | + |
| 2224 | + def extend(b, a): # MODIFIES! |
| 2225 | + c = bytearray(b'0' * 0x400000) |
| 2226 | + b.wait() |
| 2227 | + a.extend(c) |
| 2228 | + |
| 2229 | + def remove(b, a): # MODIFIES! |
| 2230 | + c = ord('0') |
| 2231 | + b.wait() |
| 2232 | + try: a.remove(c) |
| 2233 | + except ValueError: pass |
| 2234 | + |
| 2235 | + def reverse(b, a): # modifies inplace |
| 2236 | + b.wait() |
| 2237 | + a.reverse() |
| 2238 | + |
| 2239 | + def reduce(b, a): |
| 2240 | + b.wait() |
| 2241 | + a.__reduce__() |
| 2242 | + |
| 2243 | + def reduceex2(b, a): |
| 2244 | + b.wait() |
| 2245 | + a.__reduce_ex__(2) |
| 2246 | + |
| 2247 | + def reduceex3(b, a): |
| 2248 | + b.wait() |
| 2249 | + c = a.__reduce_ex__(3) |
| 2250 | + assert not c[1] or 0xdd not in c[1][0] |
| 2251 | + |
| 2252 | + def count0(b, a): |
| 2253 | + b.wait() |
| 2254 | + a.count(0) |
| 2255 | + |
| 2256 | + def decode(b, a): |
| 2257 | + b.wait() |
| 2258 | + a.decode() |
| 2259 | + |
| 2260 | + def find(b, a): |
| 2261 | + c = bytearray(b'0' * 0x40000) |
| 2262 | + b.wait() |
| 2263 | + a.find(c) |
| 2264 | + |
| 2265 | + def hex(b, a): |
| 2266 | + b.wait() |
| 2267 | + a.hex('_') |
| 2268 | + |
| 2269 | + def join(b, a): |
| 2270 | + b.wait() |
| 2271 | + a.join([b'1', b'2', b'3']) |
| 2272 | + |
| 2273 | + def replace(b, a): |
| 2274 | + b.wait() |
| 2275 | + a.replace(b'0', b'') |
| 2276 | + |
| 2277 | + def maketrans(b, a, c): |
| 2278 | + b.wait() |
| 2279 | + try: a.maketrans(a, c) |
| 2280 | + except ValueError: pass |
| 2281 | + |
| 2282 | + def translate(b, a, c): |
| 2283 | + b.wait() |
| 2284 | + a.translate(c) |
| 2285 | + |
| 2286 | + def copy(b, a): |
| 2287 | + b.wait() |
| 2288 | + c = a.copy() |
| 2289 | + if c: assert c[0] == 48 # '0' |
| 2290 | + |
| 2291 | + def endswith(b, a): |
| 2292 | + b.wait() |
| 2293 | + assert not a.endswith(b'\xdd') |
| 2294 | + |
| 2295 | + def index(b, a): |
| 2296 | + b.wait() |
| 2297 | + try: a.index(b'\xdd') |
| 2298 | + except ValueError: return |
| 2299 | + assert False |
| 2300 | + |
| 2301 | + def lstrip(b, a): |
| 2302 | + b.wait() |
| 2303 | + assert not a.lstrip(b'0') |
| 2304 | + |
| 2305 | + def partition(b, a): |
| 2306 | + b.wait() |
| 2307 | + assert not a.partition(b'\xdd')[2] |
| 2308 | + |
| 2309 | + def removeprefix(b, a): |
| 2310 | + b.wait() |
| 2311 | + assert not a.removeprefix(b'0') |
| 2312 | + |
| 2313 | + def removesuffix(b, a): |
| 2314 | + b.wait() |
| 2315 | + assert not a.removesuffix(b'0') |
| 2316 | + |
| 2317 | + def rfind(b, a): |
| 2318 | + b.wait() |
| 2319 | + assert a.rfind(b'\xdd') == -1 |
| 2320 | + |
| 2321 | + def rindex(b, a): |
| 2322 | + b.wait() |
| 2323 | + try: a.rindex(b'\xdd') |
| 2324 | + except ValueError: return |
| 2325 | + assert False |
| 2326 | + |
| 2327 | + def rpartition(b, a): |
| 2328 | + b.wait() |
| 2329 | + assert not a.rpartition(b'\xdd')[0] |
| 2330 | + |
| 2331 | + def rsplit(b, a): |
| 2332 | + b.wait() |
| 2333 | + assert len(a.rsplit(b'\xdd')) == 1 |
| 2334 | + |
| 2335 | + def rstrip(b, a): |
| 2336 | + b.wait() |
| 2337 | + assert not a.rstrip(b'0') |
| 2338 | + |
| 2339 | + def split(b, a): |
| 2340 | + b.wait() |
| 2341 | + assert len(a.split(b'\xdd')) == 1 |
| 2342 | + |
| 2343 | + def splitlines(b, a): |
| 2344 | + b.wait() |
| 2345 | + l = len(a.splitlines()) |
| 2346 | + assert l > 1 or l == 0 |
| 2347 | + |
| 2348 | + def startswith(b, a): |
| 2349 | + b.wait() |
| 2350 | + assert not a.startswith(b'\xdd') |
| 2351 | + |
| 2352 | + def strip(b, a): |
| 2353 | + b.wait() |
| 2354 | + assert not a.strip(b'0') |
| 2355 | + |
| 2356 | + def repeat(b, a): |
| 2357 | + b.wait() |
| 2358 | + a * 2 |
| 2359 | + |
| 2360 | + def contains(b, a): |
| 2361 | + b.wait() |
| 2362 | + assert 0xdd not in a |
| 2363 | + |
| 2364 | + def iconcat(b, a): # MODIFIES! |
| 2365 | + c = bytearray(b'0' * 0x400000) |
| 2366 | + b.wait() |
| 2367 | + a += c |
| 2368 | + |
| 2369 | + def irepeat(b, a): # MODIFIES! |
| 2370 | + b.wait() |
| 2371 | + a *= 2 |
| 2372 | + |
| 2373 | + def subscript(b, a): |
| 2374 | + b.wait() |
| 2375 | + try: assert a[0] != 0xdd |
| 2376 | + except IndexError: pass |
| 2377 | + |
| 2378 | + def ass_subscript(b, a): # MODIFIES! |
| 2379 | + c = bytearray(b'0' * 0x400000) |
| 2380 | + b.wait() |
| 2381 | + a[:] = c |
| 2382 | + |
| 2383 | + def mod(b, a): |
| 2384 | + c = tuple(range(4096)) |
| 2385 | + b.wait() |
| 2386 | + try: a % c |
| 2387 | + except TypeError: pass |
| 2388 | + |
| 2389 | + def repr_(b, a): |
| 2390 | + b.wait() |
| 2391 | + repr(a) |
| 2392 | + |
| 2393 | + def capitalize(b, a): |
| 2394 | + b.wait() |
| 2395 | + c = a.capitalize() |
| 2396 | + assert not c or c[0] not in (0xdd, 0xcd) |
| 2397 | + |
| 2398 | + def center(b, a): |
| 2399 | + b.wait() |
| 2400 | + c = a.center(0x60000) |
| 2401 | + assert not c or c[0x20000] not in (0xdd, 0xcd) |
| 2402 | + |
| 2403 | + def expandtabs(b, a): |
| 2404 | + b.wait() |
| 2405 | + c = a.expandtabs() |
| 2406 | + assert not c or c[0] not in (0xdd, 0xcd) |
| 2407 | + |
| 2408 | + def ljust(b, a): |
| 2409 | + b.wait() |
| 2410 | + c = a.ljust(0x600000) |
| 2411 | + assert not c or c[0] not in (0xdd, 0xcd) |
| 2412 | + |
| 2413 | + def lower(b, a): |
| 2414 | + b.wait() |
| 2415 | + c = a.lower() |
| 2416 | + assert not c or c[0] not in (0xdd, 0xcd) |
| 2417 | + |
| 2418 | + def rjust(b, a): |
| 2419 | + b.wait() |
| 2420 | + c = a.rjust(0x600000) |
| 2421 | + assert not c or c[-1] not in (0xdd, 0xcd) |
| 2422 | + |
| 2423 | + def swapcase(b, a): |
| 2424 | + b.wait() |
| 2425 | + c = a.swapcase() |
| 2426 | + assert not c or c[-1] not in (0xdd, 0xcd) |
| 2427 | + |
| 2428 | + def title(b, a): |
| 2429 | + b.wait() |
| 2430 | + c = a.title() |
| 2431 | + assert not c or c[-1] not in (0xdd, 0xcd) |
| 2432 | + |
| 2433 | + def upper(b, a): |
| 2434 | + b.wait() |
| 2435 | + c = a.upper() |
| 2436 | + assert not c or c[-1] not in (0xdd, 0xcd) |
| 2437 | + |
| 2438 | + def zfill(b, a): |
| 2439 | + b.wait() |
| 2440 | + c = a.zfill(0x400000) |
| 2441 | + assert not c or c[-1] not in (0xdd, 0xcd) |
| 2442 | + |
| 2443 | + def check(funcs, a=None, *args): |
| 2444 | + if a is None: |
| 2445 | + a = bytearray(b'0' * 0x400000) |
| 2446 | + |
| 2447 | + barrier = threading.Barrier(len(funcs)) |
| 2448 | + threads = [] |
| 2449 | + |
| 2450 | + for func in funcs: |
| 2451 | + thread = threading.Thread(target=func, args=(barrier, a, *args)) |
| 2452 | + |
| 2453 | + threads.append(thread) |
| 2454 | + |
| 2455 | + with threading_helper.start_threads(threads): |
| 2456 | + pass |
| 2457 | + |
| 2458 | + for thread in threads: |
| 2459 | + threading_helper.join_thread(thread) |
| 2460 | + |
| 2461 | + # hard errors |
| 2462 | + |
| 2463 | + check([clear] + [reduce] * 10) |
| 2464 | + check([clear] + [reduceex2] * 10) |
| 2465 | + check([clear] + [append1] * 10) |
| 2466 | + check([clear] * 10) |
| 2467 | + check([clear] + [count0] * 10) |
| 2468 | + check([clear] + [decode] * 10) |
| 2469 | + check([clear] + [extend] * 10) |
| 2470 | + check([clear] + [find] * 10) |
| 2471 | + check([clear] + [hex] * 10) |
| 2472 | + check([clear] + [insert1] * 10) |
| 2473 | + check([clear] + [join] * 10) |
| 2474 | + check([clear] + [pop1] * 10) |
| 2475 | + check([clear] + [remove] * 10) |
| 2476 | + check([clear] + [replace] * 10) |
| 2477 | + check([clear] + [reverse] * 10) |
| 2478 | + check([clear, clear2] + [maketrans] * 10, bytearray(range(128)), bytearray(range(128))) |
| 2479 | + check([clear] + [translate] * 10, None, bytearray.maketrans(bytearray(range(128)), bytearray(range(128)))) |
| 2480 | + |
| 2481 | + check([clear] + [repeat] * 10) |
| 2482 | + check([clear] + [iconcat] * 10) |
| 2483 | + check([clear] + [irepeat] * 10) |
| 2484 | + check([clear] + [ass_subscript] * 10) |
| 2485 | + check([clear] + [repr_] * 10) |
| 2486 | + |
| 2487 | + # value errors |
| 2488 | + |
| 2489 | + check([clear] + [reduceex3] * 10, bytearray(b'a' * 0x40000)) |
| 2490 | + check([clear] + [copy] * 10) |
| 2491 | + check([clear] + [endswith] * 10) |
| 2492 | + check([clear] + [index] * 10) |
| 2493 | + check([clear] + [lstrip] * 10) |
| 2494 | + check([clear] + [partition] * 10) |
| 2495 | + check([clear] + [removeprefix] * 10, bytearray(b'0')) |
| 2496 | + check([clear] + [removesuffix] * 10, bytearray(b'0')) |
| 2497 | + check([clear] + [rfind] * 10) |
| 2498 | + check([clear] + [rindex] * 10) |
| 2499 | + check([clear] + [rpartition] * 10) |
| 2500 | + check([clear] + [rsplit] * 10, bytearray(b'0' * 0x4000)) |
| 2501 | + check([clear] + [rstrip] * 10) |
| 2502 | + check([clear] + [split] * 10, bytearray(b'0' * 0x4000)) |
| 2503 | + check([clear] + [splitlines] * 10, bytearray(b'\n' * 0x400)) |
| 2504 | + check([clear] + [startswith] * 10) |
| 2505 | + check([clear] + [strip] * 10) |
| 2506 | + |
| 2507 | + check([clear] + [contains] * 10) |
| 2508 | + check([clear] + [subscript] * 10) |
| 2509 | + check([clear] + [mod] * 10, bytearray(b'%d' * 4096)) |
| 2510 | + |
| 2511 | + check([clear] + [capitalize] * 10, bytearray(b'a' * 0x40000)) |
| 2512 | + check([clear] + [center] * 10, bytearray(b'a' * 0x40000)) |
| 2513 | + check([clear] + [expandtabs] * 10, bytearray(b'0\t' * 4096)) |
| 2514 | + check([clear] + [ljust] * 10, bytearray(b'0' * 0x400000)) |
| 2515 | + check([clear] + [lower] * 10, bytearray(b'A' * 0x400000)) |
| 2516 | + check([clear] + [rjust] * 10, bytearray(b'0' * 0x400000)) |
| 2517 | + check([clear] + [swapcase] * 10, bytearray(b'aA' * 0x200000)) |
| 2518 | + check([clear] + [title] * 10, bytearray(b'aA' * 0x200000)) |
| 2519 | + check([clear] + [upper] * 10, bytearray(b'a' * 0x400000)) |
| 2520 | + check([clear] + [zfill] * 10, bytearray(b'1' * 0x200000)) |
| 2521 | + |
| 2522 | + |
2188 | 2523 | if __name__ == "__main__":
|
2189 | 2524 | unittest.main()
|
0 commit comments