diff --git a/itsdangerous.py b/itsdangerous.py index 0d7767f..b83d7fd 100644 --- a/itsdangerous.py +++ b/itsdangerous.py @@ -45,8 +45,8 @@ class _CompactJSON(object): def loads(self, payload): return json.loads(payload) - def dumps(self, obj): - return json.dumps(obj, ensure_ascii=False, separators=(',', ':')) + def dumps(self, obj, **kwargs): + return json.dumps(obj, ensure_ascii=False, separators=(',', ':'), **kwargs) compact_json = _CompactJSON() @@ -504,7 +504,7 @@ class Serializer(object): default_signer = Signer def __init__(self, secret_key, salt=b'itsdangerous', serializer=None, - signer=None, signer_kwargs=None): + signer=None, signer_kwargs=None, serializer_kwargs=None): self.secret_key = want_bytes(secret_key) self.salt = want_bytes(salt) if serializer is None: @@ -515,6 +515,7 @@ def __init__(self, secret_key, salt=b'itsdangerous', serializer=None, signer = self.default_signer self.signer = signer self.signer_kwargs = signer_kwargs or {} + self.serializer_kwargs = serializer_kwargs or {} def load_payload(self, payload, serializer=None): """Loads the encoded object. This function raises :class:`BadPayload` @@ -541,7 +542,7 @@ def dump_payload(self, obj): bytestring. If the internal serializer is text based the value will automatically be encoded to utf-8. """ - return want_bytes(self.serializer.dumps(obj)) + return want_bytes(self.serializer.dumps(obj, **self.serializer_kwargs)) def make_signer(self, salt=None): """A method that creates a new instance of the signer to be used. @@ -664,9 +665,9 @@ class JSONWebSignatureSerializer(Serializer): default_serializer = compact_json def __init__(self, secret_key, salt=None, serializer=None, - signer=None, signer_kwargs=None, algorithm_name=None): + signer=None, signer_kwargs=None, serializer_kwargs=None, algorithm_name=None): Serializer.__init__(self, secret_key, salt, serializer, - signer, signer_kwargs) + signer, signer_kwargs, serializer_kwargs) if algorithm_name is None: algorithm_name = self.default_algorithm self.algorithm_name = algorithm_name @@ -702,8 +703,8 @@ def load_payload(self, payload, return_header=False): return payload def dump_payload(self, header, obj): - base64d_header = base64_encode(self.serializer.dumps(header)) - base64d_payload = base64_encode(self.serializer.dumps(obj)) + base64d_header = base64_encode(self.serializer.dumps(header, **self.serializer_kwargs)) + base64d_payload = base64_encode(self.serializer.dumps(obj, **self.serializer_kwargs)) return base64d_header + b'.' + base64d_payload def make_algorithm(self, algorithm_name): diff --git a/tests.py b/tests.py index 067cc3e..44e34af 100755 --- a/tests.py +++ b/tests.py @@ -120,6 +120,16 @@ def test_signer_kwargs(self): ts = s.dumps(value) self.assertEqual(s.loads(ts), u'hello') + def test_serializer_kwargs(self): + secret_key = 'predictable-key' + + s = self.make_serializer(secret_key, serializer_kwargs={'sort_keys': True}) + + ts1 = s.dumps({'c': 3, 'a': 1, 'b': 2}) + ts2 = s.dumps(dict(a=1, b=2, c=3)) + + self.assertEqual(ts1, ts2) + class TimedSerializerTestCase(SerializerTestCase): serializer_class = idmod.TimedSerializer @@ -282,6 +292,7 @@ def test_invalid_base64_does_not_fail_load_payload(self): class PickleSerializerMixin(object): def make_serializer(self, *args, **kwargs): + kwargs.pop('serializer_kwargs', '') kwargs.setdefault('serializer', pickle) return super(PickleSerializerMixin, self).make_serializer(*args, **kwargs)