Skip to content

Commit c9380ae

Browse files
authored
gh-141510: Check argument in PyDict_Contains() (#145083)
PyDict_Contains() and PyDict_ContainsString() now fail with SystemError if the first argument is not a dict, frozendict, dict subclass or frozendict subclass.
1 parent f1f61bf commit c9380ae

File tree

2 files changed

+26
-10
lines changed

2 files changed

+26
-10
lines changed

Lib/test/test_capi/test_dict.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ def test_dict_getitemwitherror(self):
223223
# CRASHES getitem(NULL, 'a')
224224

225225
def test_dict_contains(self):
226+
# Test PyDict_Contains()
226227
contains = _testlimitedcapi.dict_contains
227228
dct = {'a': 1, '\U0001f40d': 2}
228229
self.assertTrue(contains(dct, 'a'))
@@ -235,11 +236,12 @@ def test_dict_contains(self):
235236

236237
self.assertRaises(TypeError, contains, {}, []) # unhashable
237238
# CRASHES contains({}, NULL)
238-
# CRASHES contains(UserDict(), 'a')
239-
# CRASHES contains(42, 'a')
239+
self.assertRaises(SystemError, contains, UserDict(), 'a')
240+
self.assertRaises(SystemError, contains, 42, 'a')
240241
# CRASHES contains(NULL, 'a')
241242

242243
def test_dict_contains_string(self):
244+
# Test PyDict_ContainsString()
243245
contains_string = _testcapi.dict_containsstring
244246
dct = {'a': 1, '\U0001f40d': 2}
245247
self.assertTrue(contains_string(dct, b'a'))
@@ -251,6 +253,8 @@ def test_dict_contains_string(self):
251253
self.assertTrue(contains_string(dct2, b'a'))
252254
self.assertFalse(contains_string(dct2, b'b'))
253255

256+
self.assertRaises(SystemError, contains_string, UserDict(), 'a')
257+
self.assertRaises(SystemError, contains_string, 42, 'a')
254258
# CRASHES contains({}, NULL)
255259
# CRASHES contains(NULL, b'a')
256260

Objects/dictobject.c

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ static PyObject* frozendict_new(PyTypeObject *type, PyObject *args,
140140
PyObject *kwds);
141141
static PyObject* dict_new(PyTypeObject *type, PyObject *args, PyObject *kwds);
142142
static int dict_merge(PyObject *a, PyObject *b, int override);
143+
static int dict_contains(PyObject *op, PyObject *key);
143144
static int dict_merge_from_seq2(PyObject *d, PyObject *seq2, int override);
144145

145146

@@ -4126,7 +4127,7 @@ dict_merge(PyObject *a, PyObject *b, int override)
41264127

41274128
for (key = PyIter_Next(iter); key; key = PyIter_Next(iter)) {
41284129
if (override != 1) {
4129-
status = PyDict_Contains(a, key);
4130+
status = dict_contains(a, key);
41304131
if (status != 0) {
41314132
if (status > 0) {
41324133
if (override == 0) {
@@ -4484,7 +4485,7 @@ static PyObject *
44844485
dict___contains___impl(PyDictObject *self, PyObject *key)
44854486
/*[clinic end generated code: output=1b314e6da7687dae input=fe1cb42ad831e820]*/
44864487
{
4487-
int contains = PyDict_Contains((PyObject *)self, key);
4488+
int contains = dict_contains((PyObject *)self, key);
44884489
if (contains < 0) {
44894490
return NULL;
44904491
}
@@ -4984,9 +4985,8 @@ static PyMethodDef mapp_methods[] = {
49844985
{NULL, NULL} /* sentinel */
49854986
};
49864987

4987-
/* Return 1 if `key` is in dict `op`, 0 if not, and -1 on error. */
4988-
int
4989-
PyDict_Contains(PyObject *op, PyObject *key)
4988+
static int
4989+
dict_contains(PyObject *op, PyObject *key)
49904990
{
49914991
Py_hash_t hash = _PyObject_HashFast(key);
49924992
if (hash == -1) {
@@ -4997,6 +4997,18 @@ PyDict_Contains(PyObject *op, PyObject *key)
49974997
return _PyDict_Contains_KnownHash(op, key, hash);
49984998
}
49994999

5000+
/* Return 1 if `key` is in dict `op`, 0 if not, and -1 on error. */
5001+
int
5002+
PyDict_Contains(PyObject *op, PyObject *key)
5003+
{
5004+
if (!PyAnyDict_Check(op)) {
5005+
PyErr_BadInternalCall();
5006+
return -1;
5007+
}
5008+
5009+
return dict_contains(op, key);
5010+
}
5011+
50005012
int
50015013
PyDict_ContainsString(PyObject *op, const char *key)
50025014
{
@@ -5013,7 +5025,7 @@ PyDict_ContainsString(PyObject *op, const char *key)
50135025
int
50145026
_PyDict_Contains_KnownHash(PyObject *op, PyObject *key, Py_hash_t hash)
50155027
{
5016-
PyDictObject *mp = (PyDictObject *)op;
5028+
PyDictObject *mp = _PyAnyDict_CAST(op);
50175029
PyObject *value;
50185030
Py_ssize_t ix;
50195031

@@ -5042,7 +5054,7 @@ static PySequenceMethods dict_as_sequence = {
50425054
0, /* sq_slice */
50435055
0, /* sq_ass_item */
50445056
0, /* sq_ass_slice */
5045-
PyDict_Contains, /* sq_contains */
5057+
dict_contains, /* sq_contains */
50465058
0, /* sq_inplace_concat */
50475059
0, /* sq_inplace_repeat */
50485060
};
@@ -6292,7 +6304,7 @@ dictkeys_contains(PyObject *self, PyObject *obj)
62926304
_PyDictViewObject *dv = (_PyDictViewObject *)self;
62936305
if (dv->dv_dict == NULL)
62946306
return 0;
6295-
return PyDict_Contains((PyObject *)dv->dv_dict, obj);
6307+
return dict_contains((PyObject *)dv->dv_dict, obj);
62966308
}
62976309

62986310
static PySequenceMethods dictkeys_as_sequence = {

0 commit comments

Comments
 (0)