Skip to content

Commit 6b7d304

Browse files
absurdfarce authored and dkropachev committed
PYTHON-1371 Add explicit exception type for serialization failures (datastax#1193)
1 parent 7426cbc commit 6b7d304

File tree

3 files changed

+83
-13
lines changed

3 files changed

+83
-13
lines changed

cassandra/__init__.py

+6
Original file line number | Diff line number | Diff line change
@@ -764,3 +764,9 @@ def __init__(self, msg, excs=[]):
764764
if excs:
765765
complete_msg += ("The following exceptions were observed: \n" + '\n'.join(str(e) for e in excs))
766766
Exception.__init__(self, complete_msg)
767+
768+
class VectorDeserializationFailure(DriverException):
769+
"""
770+
The driver was unable to deserialize a given vector
771+
"""
772+
pass

cassandra/cqltypes.py

+14-3
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,7 @@
4949
float_pack, float_unpack, double_pack, double_unpack,
5050
varint_pack, varint_unpack, point_be, point_le,
5151
vints_pack, vints_unpack)
52-
from cassandra import util
52+
from cassandra import util, VectorDeserializationFailure
5353

5454
_little_endian_flag = 1 # we always serialize LE
5555
import ipaddress
@@ -460,6 +460,7 @@ def serialize(uuid, protocol_version):
460460

461461
class BooleanType(_CassandraType):
462462
typename = 'boolean'
463+
serial_size = 1
463464

464465
@staticmethod
465466
def deserialize(byts, protocol_version):
@@ -499,6 +500,7 @@ def serialize(var, protocol_version):
499500

500501
class FloatType(_CassandraType):
501502
typename = 'float'
503+
serial_size = 4
502504

503505
@staticmethod
504506
def deserialize(byts, protocol_version):
@@ -511,6 +513,7 @@ def serialize(byts, protocol_version):
511513

512514
class DoubleType(_CassandraType):
513515
typename = 'double'
516+
serial_size = 8
514517

515518
@staticmethod
516519
def deserialize(byts, protocol_version):
@@ -523,6 +526,7 @@ def serialize(byts, protocol_version):
523526

524527
class LongType(_CassandraType):
525528
typename = 'bigint'
529+
serial_size = 8
526530

527531
@staticmethod
528532
def deserialize(byts, protocol_version):
@@ -535,6 +539,7 @@ def serialize(byts, protocol_version):
535539

536540
class Int32Type(_CassandraType):
537541
typename = 'int'
542+
serial_size = 4
538543

539544
@staticmethod
540545
def deserialize(byts, protocol_version):
@@ -647,6 +652,7 @@ class TimestampType(DateType):
647652

648653
class TimeUUIDType(DateType):
649654
typename = 'timeuuid'
655+
serial_size = 16
650656

651657
def my_timestamp(self):
652658
return util.unix_time_from_uuid1(self.val)
@@ -693,6 +699,7 @@ def serialize(val, protocol_version):
693699

694700
class ShortType(_CassandraType):
695701
typename = 'smallint'
702+
serial_size = 2
696703

697704
@staticmethod
698705
def deserialize(byts, protocol_version):
@@ -705,6 +712,7 @@ def serialize(byts, protocol_version):
705712

706713
class TimeType(_CassandraType):
707714
typename = 'time'
715+
serial_size = 8
708716

709717
@staticmethod
710718
def deserialize(byts, protocol_version):
@@ -1419,8 +1427,11 @@ def apply_parameters(cls, params, names):
14191427

14201428
@classmethod
14211429
def deserialize(cls, byts, protocol_version):
1422-
indexes = (4 * x for x in range(0, cls.vector_size))
1423-
return [cls.subtype.deserialize(byts[idx:idx + 4], protocol_version) for idx in indexes]
1430+
serialized_size = getattr(cls.subtype, "serial_size", None)
1431+
if not serialized_size:
1432+
raise VectorDeserializationFailure("Cannot determine serialized size for vector with subtype %s" % cls.subtype.__name__)
1433+
indexes = (serialized_size * x for x in range(0, cls.vector_size))
1434+
return [cls.subtype.deserialize(byts[idx:idx + serialized_size], protocol_version) for idx in indexes]
14241435

14251436
@classmethod
14261437
def serialize(cls, v, protocol_version):

tests/unit/test_types.py

+63-10
Original file line number | Diff line number | Diff line change
@@ -16,10 +16,11 @@
1616
import datetime
1717
import tempfile
1818
import time
19+
import uuid
1920
from binascii import unhexlify
2021

2122
import cassandra
22-
from cassandra import util
23+
from cassandra import util, VectorDeserializationFailure
2324
from cassandra.cqltypes import (
2425
CassandraType, DateRangeType, DateType, DecimalType,
2526
EmptyValue, LongType, SetType, UTF8Type,
@@ -308,15 +309,67 @@ def test_cql_quote(self):
308309
self.assertEqual(cql_quote('test'), "'test'")
309310
self.assertEqual(cql_quote(0), '0')
310311

311-
def test_vector_round_trip(self):
312-
base = [3.4, 2.9, 41.6, 12.0]
313-
ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)")
314-
base_bytes = ctype.serialize(base, 0)
315-
self.assertEqual(16, len(base_bytes))
316-
result = ctype.deserialize(base_bytes, 0)
317-
self.assertEqual(len(base), len(result))
318-
for idx in range(0,len(base)):
319-
self.assertAlmostEqual(base[idx], result[idx], places=5)
312+
def test_vector_round_trip_types_with_serialized_size(self):
313+
# Test all the types which specify a serialized size... see PYTHON-1371 for details
314+
self._round_trip_test([True, False, False, True], \
315+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.BooleanType, 4)")
316+
self._round_trip_test([3.4, 2.9, 41.6, 12.0], \
317+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)")
318+
self._round_trip_test([3.4, 2.9, 41.6, 12.0], \
319+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.DoubleType, 4)")
320+
self._round_trip_test([3, 2, 41, 12], \
321+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.LongType, 4)")
322+
self._round_trip_test([3, 2, 41, 12], \
323+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.Int32Type, 4)")
324+
self._round_trip_test([uuid.uuid1(), uuid.uuid1(), uuid.uuid1(), uuid.uuid1()], \
325+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.TimeUUIDType, 4)")
326+
self._round_trip_test([3, 2, 41, 12], \
327+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.ShortType, 4)")
328+
self._round_trip_test([datetime.time(1,1,1), datetime.time(2,2,2), datetime.time(3,3,3)], \
329+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.TimeType, 3)")
330+
331+
def test_vector_round_trip_types_without_serialized_size(self):
332+
# Test all the types which do not specify a serialized size... see PYTHON-1371 for details
333+
# Varints
334+
with self.assertRaises(VectorDeserializationFailure):
335+
self._round_trip_test([3, 2, 41, 12], \
336+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 4)")
337+
# ASCII text
338+
with self.assertRaises(VectorDeserializationFailure):
339+
self._round_trip_test(["abc", "def", "ghi", "jkl"], \
340+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.AsciiType, 4)")
341+
# UTF8 text
342+
with self.assertRaises(VectorDeserializationFailure):
343+
self._round_trip_test(["abc", "def", "ghi", "jkl"], \
344+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.UTF8Type, 4)")
345+
# Duration (contains varints)
346+
with self.assertRaises(VectorDeserializationFailure):
347+
self._round_trip_test([util.Duration(1,1,1), util.Duration(2,2,2), util.Duration(3,3,3)], \
348+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.DurationType, 3)")
349+
# List (of otherwise serializable type)
350+
with self.assertRaises(VectorDeserializationFailure):
351+
self._round_trip_test([[3.4], [2.9], [41.6], [12.0]], \
352+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.ListType(org.apache.cassandra.db.marshal.FloatType), 4)")
353+
# Set (of otherwise serializable type)
354+
with self.assertRaises(VectorDeserializationFailure):
355+
self._round_trip_test([set([3.4]), set([2.9]), set([41.6]), set([12.0])], \
356+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.FloatType), 4)")
357+
# Map (of otherwise serializable types)
358+
with self.assertRaises(VectorDeserializationFailure):
359+
self._round_trip_test([{1:3.4}, {2:2.9}, {3:41.6}, {4:12.0}], \
360+
"org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.MapType \
361+
(org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.FloatType), 4)")
362+
363+
def _round_trip_test(self, data, ctype_str):
364+
ctype = parse_casstype_args(ctype_str)
365+
data_bytes = ctype.serialize(data, 0)
366+
serialized_size = getattr(ctype.subtype, "serial_size", None)
367+
if serialized_size:
368+
self.assertEqual(serialized_size * len(data), len(data_bytes))
369+
result = ctype.deserialize(data_bytes, 0)
370+
self.assertEqual(len(data), len(result))
371+
for idx in range(0,len(data)):
372+
self.assertAlmostEqual(data[idx], result[idx], places=5)
320373

321374
def test_vector_cql_parameterized_type(self):
322375
ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)")

0 commit comments

Comments
 (0)