|
21 | 21 | """ |
22 | 22 |
|
23 | 23 | import re |
| 24 | +import sys |
24 | 25 | from datetime import datetime, timedelta |
| 26 | +from json import loads as json_loads |
25 | 27 | from warnings import catch_warnings |
26 | 28 |
|
27 | 29 |
|
@@ -1197,3 +1199,162 @@ def test(warning): |
1197 | 1199 | context = AssertWarnsRegexContext(warning_type, regex, msg_fmt) |
1198 | 1200 | context.add_test(test) |
1199 | 1201 | return context |
| 1202 | + |
| 1203 | + |
| 1204 | +if sys.version_info >= (3,): |
| 1205 | + _Str = str |
| 1206 | +else: |
| 1207 | + _Str = unicode |
| 1208 | + |
| 1209 | + |
| 1210 | +def assert_json_subset(first, second): |
| 1211 | + """Assert that a JSON object or array is a subset of another JSON object |
| 1212 | + or array. |
| 1213 | +
|
| 1214 | + The first JSON object or array must be supplied as a JSON-compatible |
| 1215 | + dict or list, the JSON object or array to check must be a string, an |
| 1216 | + UTF-8 bytes object, or a JSON-compatible list or dict. |
| 1217 | +
|
| 1218 | + A JSON non-object, non-array value is the subset of another JSON value, |
| 1219 | + if they are equal. |
| 1220 | +
|
| 1221 | + A JSON object is the subset of another JSON object if for each name/value |
| 1222 | + pair in the former there is a name/value pair in the latter with the same |
| 1223 | + name. Additionally the value of the former pair must be a subset of the |
| 1224 | + value of the latter pair. |
| 1225 | +
|
| 1226 | + A JSON array is the subset of another JSON array, if they have the same |
| 1227 | + number of elements and each element in the former is a subset of the |
| 1228 | + corresponding element in the latter. |
| 1229 | +
|
| 1230 | + >>> assert_json_subset({}, '{}') |
| 1231 | + >>> assert_json_subset({}, '{"foo": "bar"}') |
| 1232 | + >>> assert_json_subset({"foo": "bar"}, '{}') |
| 1233 | + Traceback (most recent call last): |
| 1234 | + ... |
| 1235 | + AssertionError: name 'foo' missing |
| 1236 | + >>> assert_json_subset([1, 2], '[1, 2]') |
| 1237 | + >>> assert_json_subset([2, 1], '[1, 2]') |
| 1238 | + Traceback (most recent call last): |
| 1239 | + ... |
| 1240 | + AssertionError: element #0 differs: 2 != 1 |
| 1241 | + >>> assert_json_subset([{}], '[{"foo": "bar"}]') |
| 1242 | + >>> assert_json_subset({}, "INVALID JSON") |
| 1243 | + Traceback (most recent call last): |
| 1244 | + ... |
| 1245 | + TypeError: invalid JSON |
| 1246 | + """ |
| 1247 | + |
| 1248 | + if not isinstance(second, (dict, list, str, bytes)): |
| 1249 | + raise TypeError("second must be dict, list, str, or bytes") |
| 1250 | + if isinstance(second, bytes): |
| 1251 | + second = second.decode("utf-8") |
| 1252 | + if isinstance(second, _Str): |
| 1253 | + parsed_second = json_loads(second) |
| 1254 | + else: |
| 1255 | + parsed_second = second |
| 1256 | + |
| 1257 | + if not isinstance(parsed_second, (dict, list)): |
| 1258 | + raise AssertionError("second must decode to dict or list, not {}". |
| 1259 | + format(type(parsed_second))) |
| 1260 | + |
| 1261 | + comparer = _JSONComparer(_JSONPath("$"), first, parsed_second) |
| 1262 | + comparer.assert_() |
| 1263 | + |
| 1264 | + |
| 1265 | +class _JSONComparer: |
| 1266 | + def __init__(self, path, expected, actual): |
| 1267 | + self._path = path |
| 1268 | + self._expected = expected |
| 1269 | + self._actual = actual |
| 1270 | + |
| 1271 | + def assert_(self): |
| 1272 | + self._assert_types_are_equal() |
| 1273 | + if isinstance(self._expected, dict): |
| 1274 | + self._assert_dicts_equal() |
| 1275 | + elif isinstance(self._expected, list): |
| 1276 | + self._assert_arrays_equal() |
| 1277 | + else: |
| 1278 | + self._assert_fundamental_values_equal() |
| 1279 | + |
| 1280 | + def _assert_types_are_equal(self): |
| 1281 | + if self._types_differ(): |
| 1282 | + self._raise_different_values() |
| 1283 | + |
| 1284 | + def _types_differ(self): |
| 1285 | + if self._expected is None: |
| 1286 | + return self._actual is not None |
| 1287 | + elif isinstance(self._expected, (int, float)): |
| 1288 | + return not isinstance(self._actual, (int, float)) |
| 1289 | + for type_ in [bool, str, _Str, list, dict]: |
| 1290 | + if isinstance(self._expected, type_): |
| 1291 | + return not isinstance(self._actual, type_) |
| 1292 | + else: |
| 1293 | + raise TypeError("unsupported type {}".format(type(self._expected))) |
| 1294 | + |
| 1295 | + def _assert_dicts_equal(self): |
| 1296 | + self._assert_all_expected_keys_in_actual_dict() |
| 1297 | + for name in self._expected: |
| 1298 | + self._assert_json_value_equals_with_item(name) |
| 1299 | + |
| 1300 | + def _assert_all_expected_keys_in_actual_dict(self): |
| 1301 | + keys = set(self._expected.keys()).difference(self._actual.keys()) |
| 1302 | + if keys: |
| 1303 | + self._raise_missing_element(keys) |
| 1304 | + |
| 1305 | + def _assert_arrays_equal(self): |
| 1306 | + if len(self._expected) != len(self._actual): |
| 1307 | + self._raise_different_sizes() |
| 1308 | + for i in range(len(self._expected)): |
| 1309 | + self._assert_json_value_equals_with_item(i) |
| 1310 | + |
| 1311 | + def _assert_json_value_equals_with_item(self, item): |
| 1312 | + path = self._path.append(item) |
| 1313 | + expected = self._expected[item] |
| 1314 | + actual = self._actual[item] |
| 1315 | + _JSONComparer(path, expected, actual).assert_() |
| 1316 | + |
| 1317 | + def _assert_fundamental_values_equal(self): |
| 1318 | + if self._expected != self._actual: |
| 1319 | + self._raise_different_values() |
| 1320 | + |
| 1321 | + def _raise_different_values(self): |
| 1322 | + self._raise_assertion_error( |
| 1323 | + "element {path} differs: {expected} != {actual}") |
| 1324 | + |
| 1325 | + def _raise_different_sizes(self): |
| 1326 | + self._raise_assertion_error( |
| 1327 | + "JSON array {path} differs in size: " |
| 1328 | + "{expected_len} != {actual_len}", |
| 1329 | + expected_len=len(self._expected), |
| 1330 | + actual_len=len(self._actual)) |
| 1331 | + |
| 1332 | + def _raise_missing_element(self, keys): |
| 1333 | + if len(keys) == 1: |
| 1334 | + format_string = "element {elements} missing from element {path}" |
| 1335 | + elements = repr(next(iter(keys))) |
| 1336 | + else: |
| 1337 | + format_string = "elements {elements} missing from element {path}" |
| 1338 | + sorted_keys = sorted(keys) |
| 1339 | + elements = (", ".join(repr(k) for k in sorted_keys[:-1]) + |
| 1340 | + ", and " + repr(sorted_keys[-1])) |
| 1341 | + self._raise_assertion_error(format_string, elements=elements) |
| 1342 | + |
| 1343 | + def _raise_assertion_error(self, format_, **kwargs): |
| 1344 | + kwargs.update({ |
| 1345 | + "path": self._path, |
| 1346 | + "expected": repr(self._expected), |
| 1347 | + "actual": repr(self._actual), |
| 1348 | + }) |
| 1349 | + raise AssertionError(format_.format(**kwargs)) |
| 1350 | + |
| 1351 | + |
| 1352 | +class _JSONPath: |
| 1353 | + def __init__(self, path): |
| 1354 | + self._path = path |
| 1355 | + |
| 1356 | + def __str__(self): |
| 1357 | + return self._path |
| 1358 | + |
| 1359 | + def append(self, item): |
| 1360 | + return _JSONPath("{0}[{1}]".format(self._path, repr(item))) |
0 commit comments