1
1
import pickle # noqa: S403
2
2
import zlib
3
- from typing import Any
3
+ from typing import Any , Dict , List , NamedTuple , Optional , Tuple
4
4
5
5
from meta_memcache .base .base_serializer import BaseSerializer , EncodedValue
6
- from meta_memcache .protocol import Blob
6
+ from meta_memcache .protocol import Blob , Key
7
+ import zstandard as zstd
7
8
8
9
9
10
class MixedSerializer (BaseSerializer ):
@@ -13,13 +14,15 @@ class MixedSerializer(BaseSerializer):
13
14
LONG = 4
14
15
ZLIB_COMPRESSED = 8
15
16
BINARY = 16
17
+
16
18
COMPRESSION_THRESHOLD = 128
17
19
18
20
def __init__ (self , pickle_protocol : int = 0 ) -> None :
19
21
self ._pickle_protocol = pickle_protocol
20
22
21
23
def serialize (
22
24
self ,
25
+ key : Key ,
23
26
value : Any ,
24
27
) -> EncodedValue :
25
28
if isinstance (value , bytes ):
@@ -53,3 +56,174 @@ def unserialize(self, data: Blob, encoding_id: int) -> Any:
53
56
return bytes (data )
54
57
else :
55
58
return pickle .loads (data ) # noqa: S301
59
+
60
+
61
class DictionaryMapping(NamedTuple):
    """Pair a raw zstd dictionary with the key domains that compress with it.

    A mapping whose ``active_domains`` list is empty still describes a
    dictionary; it simply names no domain that should use it for compression.
    """

    # Raw bytes of the zstd dictionary.
    dictionary: bytes
    # Key domains whose values are compressed with this dictionary.
    active_domains: List[str]
64
+
65
+
66
class ZstdSerializer(BaseSerializer):
    """Serializer that compresses values with zstd, optionally using dictionaries.

    Values are first encoded by type (bytes, int, str, or pickle for anything
    else) and tagged with one of the base encoding ids below. Payloads that
    are large enough are then compressed and the matching compression flag is
    OR-ed into the encoding id:

    * with the dictionary registered for the key's domain, if there is one;
    * otherwise with the default zstd compressor, if one is configured;
    * otherwise with zlib.

    zstd frames are written magicless (``FORMAT_ZSTD1_MAGICLESS``), so the
    4-byte magic number is re-attached before decompression.
    """

    # Base encoding ids.
    STR = 0
    PICKLE = 1
    INT = 2
    LONG = 4
    # Compression flags, OR-ed into the base encoding id.
    ZLIB_COMPRESSED = 8
    BINARY = 16
    ZSTD_COMPRESSED = 32

    # Zstandard frame magic number 0xFD2FB528, stored little-endian. It must
    # be exactly these four bytes: it is prepended to magicless frames so
    # that the regular frame parser and decompressors accept them.
    ZSTD_MAGIC = b"\x28\xb5\x2f\xfd"
    DEFAULT_PICKLE_PROTOCOL = 5
    DEFAULT_COMPRESSION_LEVEL = 9
    DEFAULT_COMPRESSION_THRESHOLD = 128
    DEFAULT_DICT_COMPRESSION_THRESHOLD = 64

    _pickle_protocol: int
    _compression_level: int
    _default_compression_threshold: int
    _dict_compression_threshold: int
    _zstd_compressors: Dict[int, zstd.ZstdCompressor]
    _zstd_decompressors: Dict[int, zstd.ZstdDecompressor]
    _domain_to_dict_id: Dict[str, int]
    _default_zstd_compressor: Optional[zstd.ZstdCompressor]

    def __init__(
        self,
        pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL,
        compression_level: int = DEFAULT_COMPRESSION_LEVEL,
        compression_threshold: int = DEFAULT_COMPRESSION_THRESHOLD,
        dict_compression_threshold: int = DEFAULT_DICT_COMPRESSION_THRESHOLD,
        dictionary_mappings: Optional[List[DictionaryMapping]] = None,
        default_dictionary: Optional[bytes] = None,
        default_zstd: bool = True,
    ) -> None:
        """Build the compressors and decompressors used by this serializer.

        Args:
            pickle_protocol: Protocol passed to ``pickle.dumps``.
            compression_level: zstd level used for every zstd compressor.
            compression_threshold: Minimum payload size (bytes) compressed
                when no dictionary applies.
            dict_compression_threshold: Minimum payload size (bytes)
                compressed when a dictionary applies. It also replaces
                ``compression_threshold`` when ``default_dictionary`` is set.
            dictionary_mappings: Dictionaries to register. Each one gets a
                decompressor; those with active domains also get a
                compressor used for keys in those domains.
            default_dictionary: Dictionary used for keys whose domain has no
                dedicated dictionary.
            default_zstd: When there is no ``default_dictionary``, whether to
                fall back to plain zstd (``True``) or to zlib (``False``).
        """
        self._pickle_protocol = pickle_protocol
        self._compression_level = compression_level
        # With a default dictionary every key has a dictionary available, so
        # the dictionary threshold becomes the general threshold.
        self._default_compression_threshold = (
            compression_threshold
            if not default_dictionary
            else dict_compression_threshold
        )
        self._dict_compression_threshold = dict_compression_threshold
        self._zstd_compressors = {}
        self._zstd_decompressors = {}
        self._domain_to_dict_id = {}

        # Magicless frames omit the 4-byte magic; the dictionary id is
        # written so that _decompress can pick the matching decompressor.
        compression_params = zstd.ZstdCompressionParameters.from_level(
            compression_level,
            format=zstd.FORMAT_ZSTD1_MAGICLESS,
            write_content_size=True,
            write_checksum=False,
            write_dict_id=True,
        )

        if dictionary_mappings:
            for dictionary_mapping in dictionary_mappings:
                dict_id, zstd_dict = self._build_dict(dictionary_mapping.dictionary)
                self._add_dict_decompressor(dict_id, zstd_dict)
                if dictionary_mapping.active_domains:
                    # The dictionary is active for some domains
                    self._add_dict_compressor(dict_id, zstd_dict, compression_params)
                    for domain in dictionary_mapping.active_domains:
                        self._domain_to_dict_id[domain] = dict_id

        if default_dictionary:
            dict_id, zstd_dict = self._build_dict(default_dictionary)
            self._add_dict_decompressor(dict_id, zstd_dict)

            self._default_zstd_compressor = self._add_dict_compressor(
                dict_id, zstd_dict, compression_params
            )
        elif default_zstd:
            self._default_zstd_compressor = zstd.ZstdCompressor(
                compression_params=compression_params
            )
        else:
            self._default_zstd_compressor = None

        # Id 0 is reserved here for the dictionary-less decompressor.
        # NOTE(review): a dictionary whose dict_id() is 0 would have its
        # decompressor overwritten by this line, and the truthiness checks in
        # _compress/_should_compress also treat id 0 as "no dictionary".
        self._zstd_decompressors[0] = zstd.ZstdDecompressor()

    def _build_dict(self, dictionary: bytes) -> Tuple[int, zstd.ZstdCompressionDict]:
        """Wrap raw dictionary bytes and return ``(dict_id, dictionary)``."""
        zstd_dict = zstd.ZstdCompressionDict(dictionary)
        dict_id = zstd_dict.dict_id()
        return dict_id, zstd_dict

    def _add_dict_decompressor(
        self, dict_id: int, zstd_dict: zstd.ZstdCompressionDict
    ) -> zstd.ZstdDecompressor:
        """Register and return a decompressor for ``zstd_dict`` under ``dict_id``."""
        self._zstd_decompressors[dict_id] = zstd.ZstdDecompressor(dict_data=zstd_dict)
        return self._zstd_decompressors[dict_id]

    def _add_dict_compressor(
        self,
        dict_id: int,
        zstd_dict: zstd.ZstdCompressionDict,
        compression_params: zstd.ZstdCompressionParameters,
    ) -> zstd.ZstdCompressor:
        """Register and return a compressor for ``zstd_dict`` under ``dict_id``."""
        self._zstd_compressors[dict_id] = zstd.ZstdCompressor(
            dict_data=zstd_dict, compression_params=compression_params
        )
        return self._zstd_compressors[dict_id]

    def _compress(self, key: Key, data: bytes) -> Tuple[bytes, int]:
        """Compress ``data`` and return ``(compressed, compression_flag)``.

        The domain dictionary is preferred, then the default zstd
        compressor, then zlib.
        """
        if key.domain and (dict_id := self._domain_to_dict_id.get(key.domain)):
            return self._zstd_compressors[dict_id].compress(data), self.ZSTD_COMPRESSED
        elif self._default_zstd_compressor:
            return self._default_zstd_compressor.compress(data), self.ZSTD_COMPRESSED
        else:
            return zlib.compress(data), self.ZLIB_COMPRESSED

    def _decompress(self, data: Blob) -> bytes:
        """Decompress a magicless zstd frame.

        Raises:
            ValueError: If the frame names a dictionary id that has no
                registered decompressor.
        """
        # Frames are stored without the magic number; put it back so the
        # frame header can be parsed and the frame decompressed.
        data = self.ZSTD_MAGIC + data
        dict_id = zstd.get_frame_parameters(data).dict_id
        if decompressor := self._zstd_decompressors.get(dict_id):
            return decompressor.decompress(data)
        raise ValueError(f"Unknown dictionary id: {dict_id} ")

    def _should_compress(self, key: Key, data: bytes) -> bool:
        """Return whether ``data`` is large enough to be worth compressing.

        Payloads at or above the default threshold are always compressed.
        Payloads at or above the dictionary threshold are compressed only
        when the key's domain has a dictionary registered.
        """
        data_len = len(data)
        if data_len >= self._default_compression_threshold:
            return True
        elif data_len >= self._dict_compression_threshold:
            return bool(key.domain and self._domain_to_dict_id.get(key.domain))
        return False

    def serialize(
        self,
        key: Key,
        value: Any,
    ) -> EncodedValue:
        """Encode ``value`` and compress it when allowed and worthwhile.

        ``bytes`` are stored as-is, non-bool ``int`` as ASCII decimal text,
        ``str`` via ``str.encode()``, and everything else (including
        ``bool``) is pickled. Compression is skipped when
        ``key.disable_compression`` is set.
        """
        if isinstance(value, bytes):
            data = value
            encoding_id = self.BINARY
        elif isinstance(value, int) and not isinstance(value, bool):
            data = str(value).encode("ascii")
            encoding_id = self.INT
        elif isinstance(value, str):
            data = str(value).encode()
            encoding_id = self.STR
        else:
            data = pickle.dumps(value, protocol=self._pickle_protocol)
            encoding_id = self.PICKLE

        if not key.disable_compression and self._should_compress(key, data):
            data, compression_flag = self._compress(key, data)
            encoding_id |= compression_flag
        return EncodedValue(data=data, encoding_id=encoding_id)

    def unserialize(self, data: Blob, encoding_id: int) -> Any:
        """Reverse ``serialize``: decompress if flagged, then decode by type.

        Raises:
            ValueError: If a zstd payload references an unknown dictionary id.
        """
        # Strip the compression flag so the remainder is the base encoding id.
        if encoding_id & self.ZLIB_COMPRESSED:
            data = zlib.decompress(data)
            encoding_id ^= self.ZLIB_COMPRESSED
        elif encoding_id & self.ZSTD_COMPRESSED:
            data = self._decompress(data)
            encoding_id ^= self.ZSTD_COMPRESSED

        if encoding_id == self.STR:
            return bytes(data).decode()
        elif encoding_id in (self.INT, self.LONG):
            return int(data)
        elif encoding_id == self.BINARY:
            return bytes(data)
        else:
            # SECURITY: unpickling executes arbitrary code from the payload;
            # this is only safe when the cache contents are trusted.
            return pickle.loads(data)  # noqa: S301
0 commit comments