Skip to content

Commit 79f89e6

Browse files
authored
bpo-39421: Fix posible crash in heapq with custom comparison operators (pythonGH-18118)
* bpo-39421: Fix posible crash in heapq with custom comparison operators * fixup! bpo-39421: Fix posible crash in heapq with custom comparison operators * fixup! fixup! bpo-39421: Fix posible crash in heapq with custom comparison operators
1 parent 13bc139 commit 79f89e6

File tree

3 files changed

+59
-9
lines changed

3 files changed

+59
-9
lines changed

Lib/test/test_heapq.py

+31
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,37 @@ def test_heappop_mutating_heap(self):
432432
with self.assertRaises((IndexError, RuntimeError)):
433433
self.module.heappop(heap)
434434

435+
def test_comparison_operator_modifiying_heap(self):
436+
# See bpo-39421: Strong references need to be taken
437+
# when comparing objects as they can alter the heap
438+
class EvilClass(int):
439+
def __lt__(self, o):
440+
heap.clear()
441+
return NotImplemented
442+
443+
heap = []
444+
self.module.heappush(heap, EvilClass(0))
445+
self.assertRaises(IndexError, self.module.heappushpop, heap, 1)
446+
447+
def test_comparison_operator_modifiying_heap_two_heaps(self):
448+
449+
class h(int):
450+
def __lt__(self, o):
451+
list2.clear()
452+
return NotImplemented
453+
454+
class g(int):
455+
def __lt__(self, o):
456+
list1.clear()
457+
return NotImplemented
458+
459+
list1, list2 = [], []
460+
461+
self.module.heappush(list1, h(0))
462+
self.module.heappush(list2, g(0))
463+
464+
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1))
465+
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1))
435466

436467
class TestErrorHandlingPython(TestErrorHandling, TestCase):
437468
module = py_heapq
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Fix possible crashes when operating with the functions in the :mod:`heapq`
2+
module and custom comparison operators.

Modules/_heapqmodule.c

+26-9
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@ siftdown(PyListObject *heap, Py_ssize_t startpos, Py_ssize_t pos)
3636
while (pos > startpos) {
3737
parentpos = (pos - 1) >> 1;
3838
parent = arr[parentpos];
39+
Py_INCREF(newitem);
40+
Py_INCREF(parent);
3941
cmp = PyObject_RichCompareBool(newitem, parent, Py_LT);
42+
Py_DECREF(parent);
43+
Py_DECREF(newitem);
4044
if (cmp < 0)
4145
return -1;
4246
if (size != PyList_GET_SIZE(heap)) {
@@ -78,10 +82,13 @@ siftup(PyListObject *heap, Py_ssize_t pos)
7882
/* Set childpos to index of smaller child. */
7983
childpos = 2*pos + 1; /* leftmost child position */
8084
if (childpos + 1 < endpos) {
81-
cmp = PyObject_RichCompareBool(
82-
arr[childpos],
83-
arr[childpos + 1],
84-
Py_LT);
85+
PyObject* a = arr[childpos];
86+
PyObject* b = arr[childpos + 1];
87+
Py_INCREF(a);
88+
Py_INCREF(b);
89+
cmp = PyObject_RichCompareBool(a, b, Py_LT);
90+
Py_DECREF(a);
91+
Py_DECREF(b);
8592
if (cmp < 0)
8693
return -1;
8794
childpos += ((unsigned)cmp ^ 1); /* increment when cmp==0 */
@@ -264,7 +271,10 @@ _heapq_heappushpop_impl(PyObject *module, PyObject *heap, PyObject *item)
264271
return item;
265272
}
266273

267-
cmp = PyObject_RichCompareBool(PyList_GET_ITEM(heap, 0), item, Py_LT);
274+
PyObject* top = PyList_GET_ITEM(heap, 0);
275+
Py_INCREF(top);
276+
cmp = PyObject_RichCompareBool(top, item, Py_LT);
277+
Py_DECREF(top);
268278
if (cmp < 0)
269279
return NULL;
270280
if (cmp == 0) {
@@ -420,7 +430,11 @@ siftdown_max(PyListObject *heap, Py_ssize_t startpos, Py_ssize_t pos)
420430
while (pos > startpos) {
421431
parentpos = (pos - 1) >> 1;
422432
parent = arr[parentpos];
433+
Py_INCREF(parent);
434+
Py_INCREF(newitem);
423435
cmp = PyObject_RichCompareBool(parent, newitem, Py_LT);
436+
Py_DECREF(parent);
437+
Py_DECREF(newitem);
424438
if (cmp < 0)
425439
return -1;
426440
if (size != PyList_GET_SIZE(heap)) {
@@ -462,10 +476,13 @@ siftup_max(PyListObject *heap, Py_ssize_t pos)
462476
/* Set childpos to index of smaller child. */
463477
childpos = 2*pos + 1; /* leftmost child position */
464478
if (childpos + 1 < endpos) {
465-
cmp = PyObject_RichCompareBool(
466-
arr[childpos + 1],
467-
arr[childpos],
468-
Py_LT);
479+
PyObject* a = arr[childpos + 1];
480+
PyObject* b = arr[childpos];
481+
Py_INCREF(a);
482+
Py_INCREF(b);
483+
cmp = PyObject_RichCompareBool(a, b, Py_LT);
484+
Py_DECREF(a);
485+
Py_DECREF(b);
469486
if (cmp < 0)
470487
return -1;
471488
childpos += ((unsigned)cmp ^ 1); /* increment when cmp==0 */

0 commit comments

Comments
 (0)