|
16 | 16 | import datetime
|
17 | 17 | import tempfile
|
18 | 18 | import time
|
| 19 | +import uuid |
19 | 20 | from binascii import unhexlify
|
20 | 21 |
|
21 | 22 | import cassandra
|
22 |
| -from cassandra import util |
| 23 | +from cassandra import util, VectorDeserializationFailure |
23 | 24 | from cassandra.cqltypes import (
|
24 | 25 | CassandraType, DateRangeType, DateType, DecimalType,
|
25 | 26 | EmptyValue, LongType, SetType, UTF8Type,
|
@@ -308,15 +309,67 @@ def test_cql_quote(self):
|
308 | 309 | self.assertEqual(cql_quote('test'), "'test'")
|
309 | 310 | self.assertEqual(cql_quote(0), '0')
|
310 | 311 |
|
311 |
| - def test_vector_round_trip(self): |
312 |
| - base = [3.4, 2.9, 41.6, 12.0] |
313 |
| - ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)") |
314 |
| - base_bytes = ctype.serialize(base, 0) |
315 |
| - self.assertEqual(16, len(base_bytes)) |
316 |
| - result = ctype.deserialize(base_bytes, 0) |
317 |
| - self.assertEqual(len(base), len(result)) |
318 |
| - for idx in range(0,len(base)): |
319 |
| - self.assertAlmostEqual(base[idx], result[idx], places=5) |
| 312 | + def test_vector_round_trip_types_with_serialized_size(self): |
| 313 | + # Test all the types which specify a serialized size... see PYTHON-1371 for details |
| 314 | + self._round_trip_test([True, False, False, True], \ |
| 315 | + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.BooleanType, 4)") |
| 316 | + self._round_trip_test([3.4, 2.9, 41.6, 12.0], \ |
| 317 | + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)") |
| 318 | + self._round_trip_test([3.4, 2.9, 41.6, 12.0], \ |
| 319 | + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.DoubleType, 4)") |
| 320 | + self._round_trip_test([3, 2, 41, 12], \ |
| 321 | + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.LongType, 4)") |
| 322 | + self._round_trip_test([3, 2, 41, 12], \ |
| 323 | + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.Int32Type, 4)") |
| 324 | + self._round_trip_test([uuid.uuid1(), uuid.uuid1(), uuid.uuid1(), uuid.uuid1()], \ |
| 325 | + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.TimeUUIDType, 4)") |
| 326 | + self._round_trip_test([3, 2, 41, 12], \ |
| 327 | + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.ShortType, 4)") |
| 328 | + self._round_trip_test([datetime.time(1,1,1), datetime.time(2,2,2), datetime.time(3,3,3)], \ |
| 329 | + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.TimeType, 3)") |
| 330 | + |
| 331 | + def test_vector_round_trip_types_without_serialized_size(self): |
| 332 | + # Test all the types which do not specify a serialized size... see PYTHON-1371 for details |
| 333 | + # Varints |
| 334 | + with self.assertRaises(VectorDeserializationFailure): |
| 335 | + self._round_trip_test([3, 2, 41, 12], \ |
| 336 | + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 4)") |
| 337 | + # ASCII text |
| 338 | + with self.assertRaises(VectorDeserializationFailure): |
| 339 | + self._round_trip_test(["abc", "def", "ghi", "jkl"], \ |
| 340 | + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.AsciiType, 4)") |
| 341 | + # UTF8 text |
| 342 | + with self.assertRaises(VectorDeserializationFailure): |
| 343 | + self._round_trip_test(["abc", "def", "ghi", "jkl"], \ |
| 344 | + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.UTF8Type, 4)") |
| 345 | + # Duration (containts varints) |
| 346 | + with self.assertRaises(VectorDeserializationFailure): |
| 347 | + self._round_trip_test([util.Duration(1,1,1), util.Duration(2,2,2), util.Duration(3,3,3)], \ |
| 348 | + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.DurationType, 3)") |
| 349 | + # List (of otherwise serializable type) |
| 350 | + with self.assertRaises(VectorDeserializationFailure): |
| 351 | + self._round_trip_test([[3.4], [2.9], [41.6], [12.0]], \ |
| 352 | + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.ListType(org.apache.cassandra.db.marshal.FloatType), 4)") |
| 353 | + # Set (of otherwise serializable type) |
| 354 | + with self.assertRaises(VectorDeserializationFailure): |
| 355 | + self._round_trip_test([set([3.4]), set([2.9]), set([41.6]), set([12.0])], \ |
| 356 | + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.FloatType), 4)") |
| 357 | + # Map (of otherwise serializable types) |
| 358 | + with self.assertRaises(VectorDeserializationFailure): |
| 359 | + self._round_trip_test([{1:3.4}, {2:2.9}, {3:41.6}, {4:12.0}], \ |
| 360 | + "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.MapType \ |
| 361 | + (org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.FloatType), 4)") |
| 362 | + |
| 363 | + def _round_trip_test(self, data, ctype_str): |
| 364 | + ctype = parse_casstype_args(ctype_str) |
| 365 | + data_bytes = ctype.serialize(data, 0) |
| 366 | + serialized_size = getattr(ctype.subtype, "serial_size", None) |
| 367 | + if serialized_size: |
| 368 | + self.assertEqual(serialized_size * len(data), len(data_bytes)) |
| 369 | + result = ctype.deserialize(data_bytes, 0) |
| 370 | + self.assertEqual(len(data), len(result)) |
| 371 | + for idx in range(0,len(data)): |
| 372 | + self.assertAlmostEqual(data[idx], result[idx], places=5) |
320 | 373 |
|
321 | 374 | def test_vector_cql_parameterized_type(self):
|
322 | 375 | ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)")
|
|
0 commit comments