Skip to content

Commit f2234ae

Browse files
committed
Fix itertools.count in free-threading mode
Thread safety in count_next. count_next has two modes. slow mode (obj->cnt set to PY_SSIZE_T_MAX), which now uses the object mutex (only if GIL is disabled) and fast mode, which is either simple cnt++ if GIL is enabled, or uses atomic_compare_exchange if GIL is disabled.
1 parent bf17986 commit f2234ae

File tree

3 files changed

+60
-14
lines changed

3 files changed

+60
-14
lines changed

Lib/test/test_itertools.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,26 @@ def test_count_with_stride(self):
590590
self.assertEqual(type(next(c)), int)
591591
self.assertEqual(type(next(c)), float)
592592

593+
def test_count_threading(self, step=1):
594+
# this test verifies multithreading consistency, which is
595+
# mostly for testing builds without GIL, but nice to test anyway
596+
count_to = 10_000
597+
num_threads = 10
598+
c = count(step=step)
599+
def counting_thread():
600+
for i in range(count_to):
601+
next(c)
602+
threads = []
603+
for i in range(num_threads):
604+
thread = threading.Thread(target=counting_thread)
605+
thread.start()
606+
threads.append(thread)
607+
[thread.join() for thread in threads]
608+
self.assertEqual(next(c), count_to * num_threads * step)
609+
610+
def test_count_with_stride_threading(self):
611+
self.test_count_threading(5)
612+
593613
def test_cycle(self):
594614
self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
595615
self.assertEqual(list(cycle('')), [])

Modules/itertoolsmodule.c

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
#include "Python.h"
2-
#include "pycore_call.h" // _PyObject_CallNoArgs()
3-
#include "pycore_ceval.h" // _PyEval_GetBuiltin()
4-
#include "pycore_long.h" // _PyLong_GetZero()
5-
#include "pycore_moduleobject.h" // _PyModule_GetState()
6-
#include "pycore_typeobject.h" // _PyType_GetModuleState()
7-
#include "pycore_object.h" // _PyObject_GC_TRACK()
8-
#include "pycore_tuple.h" // _PyTuple_ITEMS()
2+
#include "pycore_call.h" // _PyObject_CallNoArgs()
3+
#include "pycore_ceval.h" // _PyEval_GetBuiltin()
4+
#include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION()
5+
#include "pycore_long.h" // _PyLong_GetZero()
6+
#include "pycore_moduleobject.h" // _PyModule_GetState()
7+
#include "pycore_typeobject.h" // _PyType_GetModuleState()
8+
#include "pycore_object.h" // _PyObject_GC_TRACK()
9+
#include "pycore_tuple.h" // _PyTuple_ITEMS()
910

10-
#include <stddef.h> // offsetof()
11+
#include <stddef.h> // offsetof()
1112

1213
/* Itertools module written and maintained
1314
by Raymond D. Hettinger <python@rcn.com>
@@ -3254,7 +3255,7 @@ fast_mode: when cnt an integer < PY_SSIZE_T_MAX and no step is specified.
32543255
32553256
assert(cnt != PY_SSIZE_T_MAX && long_cnt == NULL && long_step==PyLong(1));
32563257
Advances with: cnt += 1
3257-
When count hits Y_SSIZE_T_MAX, switch to slow_mode.
3258+
When count hits PY_SSIZE_T_MAX, switch to slow_mode.
32583259
32593260
slow_mode: when cnt == PY_SSIZE_T_MAX, step is not int(1), or cnt is a float.
32603261
@@ -3386,7 +3387,7 @@ count_nextlong(countobject *lz)
33863387

33873388
long_cnt = lz->long_cnt;
33883389
if (long_cnt == NULL) {
3389-
/* Switch to slow_mode */
3390+
/* Switching from fast mode */
33903391
long_cnt = PyLong_FromSsize_t(PY_SSIZE_T_MAX);
33913392
if (long_cnt == NULL)
33923393
return NULL;
@@ -3403,9 +3404,35 @@ count_nextlong(countobject *lz)
34033404
static PyObject *
34043405
count_next(countobject *lz)
34053406
{
3406-
if (lz->cnt == PY_SSIZE_T_MAX)
3407-
return count_nextlong(lz);
3408-
return PyLong_FromSsize_t(lz->cnt++);
3407+
PyObject *returned;
3408+
Py_ssize_t cnt;
3409+
3410+
cnt = FT_ATOMIC_LOAD_SSIZE_RELAXED(lz->cnt);
3411+
for (;;) {
3412+
if (cnt == PY_SSIZE_T_MAX) {
3413+
/* slow mode */
3414+
Py_BEGIN_CRITICAL_SECTION(lz);
3415+
returned = count_nextlong(lz);
3416+
Py_END_CRITICAL_SECTION();
3417+
return returned;
3418+
}
3419+
#ifdef Py_GIL_DISABLED
3420+
/* thread-safe fast version (increment by one).
3421+
* If lz->cnt changed between the pervious read and now,
3422+
* that means another thread got in our way. In this case,
3423+
* update cnt to new value of lz->cnt, and try again.
3424+
* Otherwise, (no other thread updated lz->cnt),
3425+
* atomically update lz->cnt with the incremented value and
3426+
* then return cnt (the previous value)
3427+
*/
3428+
if (_Py_atomic_compare_exchange_ssize(&lz->cnt, &cnt, cnt + 1)) {
3429+
return PyLong_FromSsize_t(cnt);
3430+
}
3431+
#else
3432+
/* fast mode when GIL is enabled */
3433+
return PyLong_FromSsize_t(lz->cnt++);
3434+
#endif
3435+
}
34093436
}
34103437

34113438
static PyObject *

Tools/tsan/suppressions_free_threading.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ race_top:_Py_dict_lookup_threadsafe
5656
race_top:_imp_release_lock
5757
race_top:_multiprocessing_SemLock_acquire_impl
5858
race_top:builtin_compile_impl
59-
race_top:count_next
6059
race_top:dictiter_new
6160
race_top:dictresize
6261
race_top:insert_to_emptydict

0 commit comments

Comments
 (0)