Skip to content

Commit ba66a2d

Browse files
committed
Capture all BSON decode errors and wrap with InvalidBSON. PYTHON-494
1 parent fecbbee commit ba66a2d

File tree

3 files changed

+107
-61
lines changed

3 files changed

+107
-61
lines changed

bson/__init__.py

+18-12
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,6 @@ def _dict_to_bson(dict, check_keys, uuid_subtype, top_level=True):
484484
_dict_to_bson = _cbson._dict_to_bson
485485

486486

487-
488487
def decode_all(data, as_class=dict,
489488
tz_aware=True, uuid_subtype=OLD_UUID_SUBTYPE):
490489
"""Decode BSON data to multiple documents.
@@ -504,17 +503,24 @@ def decode_all(data, as_class=dict,
504503
docs = []
505504
position = 0
506505
end = len(data) - 1
507-
while position < end:
508-
obj_size = struct.unpack("<i", data[position:position + 4])[0]
509-
if len(data) - position < obj_size:
510-
raise InvalidBSON("objsize too large")
511-
if data[position + obj_size - 1:position + obj_size] != ZERO:
512-
raise InvalidBSON("bad eoo")
513-
elements = data[position + 4:position + obj_size - 1]
514-
position += obj_size
515-
docs.append(_elements_to_dict(elements, as_class,
516-
tz_aware, uuid_subtype))
517-
return docs
506+
try:
507+
while position < end:
508+
obj_size = struct.unpack("<i", data[position:position + 4])[0]
509+
if len(data) - position < obj_size:
510+
raise InvalidBSON("objsize too large")
511+
if data[position + obj_size - 1:position + obj_size] != ZERO:
512+
raise InvalidBSON("bad eoo")
513+
elements = data[position + 4:position + obj_size - 1]
514+
position += obj_size
515+
docs.append(_elements_to_dict(elements, as_class,
516+
tz_aware, uuid_subtype))
517+
return docs
518+
except InvalidBSON:
519+
raise
520+
except Exception:
521+
# Change exception type to InvalidBSON but preserve traceback.
522+
exc_type, exc_value, exc_tb = sys.exc_info()
523+
raise InvalidBSON, InvalidBSON(str(exc_value)), exc_tb
518524
if _use_c:
519525
decode_all = _cbson.decode_all
520526

bson/_cbsonmodule.c

+58-47
Original file line numberDiff line numberDiff line change
@@ -1347,8 +1347,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
13471347
unsigned char tz_aware, unsigned char uuid_subtype) {
13481348
struct module_state *state = GETSTATE(self);
13491349

1350-
PyObject* value;
1351-
PyObject* error;
1350+
PyObject* value = NULL;
13521351
switch (type) {
13531352
case 1:
13541353
{
@@ -1358,9 +1357,6 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
13581357
}
13591358
memcpy(&d, buffer + *position, 8);
13601359
value = PyFloat_FromDouble(d);
1361-
if (!value) {
1362-
return NULL;
1363-
}
13641360
*position += 8;
13651361
break;
13661362
}
@@ -1373,9 +1369,6 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
13731369
}
13741370
*position += 4;
13751371
value = PyUnicode_DecodeUTF8(buffer + *position, value_length, "strict");
1376-
if (!value) {
1377-
return NULL;
1378-
}
13791372
*position += value_length + 1;
13801373
break;
13811374
}
@@ -1389,10 +1382,10 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
13891382
}
13901383
value = elements_to_dict(self, buffer + *position + 4,
13911384
size - 5, as_class, tz_aware, uuid_subtype);
1385+
13921386
if (!value) {
1393-
return NULL;
1387+
goto invalid;
13941388
}
1395-
13961389
/* Decoding for DBRefs */
13971390
collection = PyDict_GetItemString(value, "$ref");
13981391
if (collection) { /* DBRef */
@@ -1428,9 +1421,6 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
14281421
Py_DECREF(id);
14291422
Py_DECREF(collection);
14301423
Py_DECREF(database);
1431-
if (!value) {
1432-
return NULL;
1433-
}
14341424
}
14351425

14361426
*position += size;
@@ -1450,7 +1440,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
14501440

14511441
value = PyList_New(0);
14521442
if (!value) {
1453-
return NULL;
1443+
goto invalid;
14541444
}
14551445
while (*position < end) {
14561446
PyObject* to_append;
@@ -1467,7 +1457,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
14671457
max - (int)key_size, as_class, tz_aware, uuid_subtype);
14681458
if (!to_append) {
14691459
Py_DECREF(value);
1470-
return NULL;
1460+
goto invalid;
14711461
}
14721462
PyList_Append(value, to_append);
14731463
Py_DECREF(to_append);
@@ -1506,20 +1496,20 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
15061496
}
15071497
#endif
15081498
if (!data) {
1509-
return NULL;
1499+
goto invalid;
15101500
}
15111501
if ((subtype == 3 || subtype == 4) && state->UUID) { // Encode as UUID, not Binary
15121502
PyObject* kwargs;
15131503
PyObject* args = PyTuple_New(0);
15141504
if (!args) {
15151505
Py_DECREF(data);
1516-
return NULL;
1506+
goto invalid;
15171507
}
15181508
kwargs = PyDict_New();
15191509
if (!kwargs) {
15201510
Py_DECREF(data);
15211511
Py_DECREF(args);
1522-
return NULL;
1512+
goto invalid;
15231513
}
15241514

15251515
assert(length == 16); // UUID should always be 16 bytes
@@ -1553,10 +1543,6 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
15531543
Py_DECREF(args);
15541544
Py_DECREF(kwargs);
15551545
Py_DECREF(data);
1556-
if (!value) {
1557-
return NULL;
1558-
}
1559-
15601546
*position += length + 5;
15611547
break;
15621548

@@ -1579,9 +1565,6 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
15791565
value = PyObject_CallFunctionObjArgs(state->Binary, data, st, NULL);
15801566
Py_DECREF(st);
15811567
Py_DECREF(data);
1582-
if (!value) {
1583-
return NULL;
1584-
}
15851568
*position += length + 5;
15861569
break;
15871570
}
@@ -1602,9 +1585,6 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
16021585
#else
16031586
value = PyObject_CallFunction(state->ObjectId, "s#", buffer + *position, 12);
16041587
#endif
1605-
if (!value) {
1606-
return NULL;
1607-
}
16081588
*position += 12;
16091589
break;
16101590
}
@@ -1631,29 +1611,29 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
16311611
}
16321612

16331613
if (!naive) {
1634-
return NULL;
1614+
goto invalid;
16351615
}
16361616
replace = PyObject_GetAttrString(naive, "replace");
16371617
Py_DECREF(naive);
16381618
if (!replace) {
1639-
return NULL;
1619+
goto invalid;
16401620
}
16411621
args = PyTuple_New(0);
16421622
if (!args) {
16431623
Py_DECREF(replace);
1644-
return NULL;
1624+
goto invalid;
16451625
}
16461626
kwargs = PyDict_New();
16471627
if (!kwargs) {
16481628
Py_DECREF(replace);
16491629
Py_DECREF(args);
1650-
return NULL;
1630+
goto invalid;
16511631
}
16521632
if (PyDict_SetItemString(kwargs, "tzinfo", state->UTC) == -1) {
16531633
Py_DECREF(replace);
16541634
Py_DECREF(args);
16551635
Py_DECREF(kwargs);
1656-
return NULL;
1636+
goto invalid;
16571637
}
16581638
value = PyObject_Call(replace, args, kwargs);
16591639
Py_DECREF(replace);
@@ -1672,7 +1652,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
16721652
}
16731653
pattern = PyUnicode_DecodeUTF8(buffer + *position, pattern_length, "strict");
16741654
if (!pattern) {
1675-
return NULL;
1655+
goto invalid;
16761656
}
16771657
*position += (int)pattern_length + 1;
16781658
if ((flags_length = strlen(buffer + *position)) > BSON_MAX_SIZE) {
@@ -1718,14 +1698,14 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
17181698
collection = PyUnicode_DecodeUTF8(buffer + *position,
17191699
coll_length, "strict");
17201700
if (!collection) {
1721-
return NULL;
1701+
goto invalid;
17221702
}
17231703
*position += (int)coll_length + 1;
17241704

17251705
id = PyObject_CallFunction(state->ObjectId, "s#", buffer + *position, 12);
17261706
if (!id) {
17271707
Py_DECREF(collection);
1728-
return NULL;
1708+
goto invalid;
17291709
}
17301710
*position += 12;
17311711
value = PyObject_CallFunctionObjArgs(state->DBRef, collection, id, NULL);
@@ -1743,7 +1723,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
17431723
*position += 4;
17441724
code = PyUnicode_DecodeUTF8(buffer + *position, value_length, "strict");
17451725
if (!code) {
1746-
return NULL;
1726+
goto invalid;
17471727
}
17481728
*position += value_length + 1;
17491729
value = PyObject_CallFunctionObjArgs(state->Code, code, NULL, NULL);
@@ -1764,7 +1744,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
17641744
}
17651745
code = PyUnicode_DecodeUTF8(buffer + *position, code_length, "strict");
17661746
if (!code) {
1767-
return NULL;
1747+
goto invalid;
17681748
}
17691749
*position += (int)code_length + 1;
17701750

@@ -1773,7 +1753,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
17731753
(PyObject*)&PyDict_Type, tz_aware, uuid_subtype);
17741754
if (!scope) {
17751755
Py_DECREF(code);
1776-
return NULL;
1756+
goto invalid;
17771757
}
17781758
*position += scope_size;
17791759

@@ -1795,7 +1775,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
17951775
value = PyInt_FromLong(i);
17961776
#endif
17971777
if (!value) {
1798-
return NULL;
1778+
goto invalid;
17991779
}
18001780
*position += 4;
18011781
break;
@@ -1810,7 +1790,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
18101790
memcpy(&time, buffer + *position + 4, 4);
18111791
value = PyObject_CallFunction(state->Timestamp, "II", time, inc);
18121792
if (!value) {
1813-
return NULL;
1793+
goto invalid;
18141794
}
18151795
*position += 8;
18161796
break;
@@ -1824,7 +1804,7 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
18241804
memcpy(&ll, buffer + *position, 8);
18251805
value = PyLong_FromLongLong(ll);
18261806
if (!value) {
1827-
return NULL;
1807+
goto invalid;
18281808
}
18291809
*position += 8;
18301810
break;
@@ -1850,14 +1830,45 @@ static PyObject* get_value(PyObject* self, const char* buffer, int* position,
18501830
return NULL;
18511831
}
18521832
}
1853-
return value;
1833+
1834+
if (value) {
1835+
return value;
1836+
}
18541837

18551838
invalid:
18561839

1857-
error = _error("InvalidBSON");
1858-
if (error) {
1859-
PyErr_SetNone(error);
1860-
Py_DECREF(error);
1840+
/* Wrap any non-InvalidBSON errors in InvalidBSON. */
1841+
if (PyErr_Occurred()) {
1842+
/* Calling _error clears the error state, so fetch it first. */
1843+
PyObject *etype, *evalue, *etrace, *InvalidBSON;
1844+
PyErr_Fetch(&etype, &evalue, &etrace);
1845+
InvalidBSON = _error("InvalidBSON");
1846+
if (InvalidBSON) {
1847+
if (!PyErr_GivenExceptionMatches(etype, InvalidBSON)) {
1848+
/* Raise InvalidBSON(str(e)). */
1849+
PyObject *msg = NULL;
1850+
Py_DECREF(etype);
1851+
etype = InvalidBSON;
1852+
1853+
if (evalue) {
1854+
msg = PyObject_Str(evalue);
1855+
Py_DECREF(evalue);
1856+
evalue = msg;
1857+
}
1858+
PyErr_NormalizeException(&etype, &evalue, &etrace);
1859+
Py_XDECREF(msg);
1860+
}
1861+
}
1862+
/* Steals references to args. */
1863+
PyErr_Restore(etype, evalue, etrace);
1864+
Py_XDECREF(InvalidBSON);
1865+
return NULL;
1866+
} else {
1867+
PyObject *InvalidBSON = _error("InvalidBSON");
1868+
if (InvalidBSON) {
1869+
PyErr_SetNone(InvalidBSON);
1870+
Py_DECREF(InvalidBSON);
1871+
}
18611872
}
18621873
return NULL;
18631874
}

test/test_bson.py

+31-2
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616

1717
"""Test the bson module."""
1818

19-
import unittest
2019
import datetime
2120
import re
2221
import sys
22+
import traceback
23+
import unittest
2324
try:
2425
import uuid
2526
should_test_uuid = True
@@ -41,7 +42,8 @@
4142
from bson.son import SON
4243
from bson.timestamp import Timestamp
4344
from bson.errors import (InvalidDocument,
44-
InvalidStringData)
45+
InvalidStringData,
46+
InvalidBSON)
4547
from bson.max_key import MaxKey
4648
from bson.min_key import MinKey
4749
from bson.tz_util import (FixedOffset,
@@ -448,5 +450,32 @@ def test_ordered_dict(self):
448450
d = OrderedDict([("one", 1), ("two", 2), ("three", 3), ("four", 4)])
449451
self.assertEqual(d, BSON.encode(d).decode(as_class=OrderedDict))
450452

453+
def test_exception_wrapping(self):
454+
# No matter what exception is raised while trying to decode BSON,
455+
# the final exception always matches InvalidBSON and the original
456+
# is traceback preserved.
457+
458+
# Invalid Python regex, though valid PCRE: {'r': /[\w-\.]/}
459+
# Will cause an error in re.compile().
460+
bad_doc = b('"\x00\x00\x00\x07_id\x00R\x013\xd4S1\xe3\xd3\xd6Sgs'
461+
'\x0br\x00[\\w-\\.]\x00\x00\x00')
462+
463+
try:
464+
decode_all(bad_doc)
465+
except InvalidBSON:
466+
exc_type, exc_value, exc_tb = sys.exc_info()
467+
# Original re error was captured and wrapped in InvalidBSON.
468+
self.assertEqual(exc_value.args[0], 'bad character range')
469+
470+
# Traceback includes bson module's call into re module.
471+
for filename, lineno, fname, text in traceback.extract_tb(exc_tb):
472+
if filename.endswith('re.py') and fname == 'compile':
473+
# Traceback was correctly preserved.
474+
break
475+
else:
476+
self.fail('Traceback not captured')
477+
else:
478+
self.fail('InvalidBSON not raised')
479+
451480
if __name__ == "__main__":
452481
unittest.main()

0 commit comments

Comments
 (0)