1
1
import pickle # noqa: S403
2
2
import zlib
3
- from typing import Any
3
+ from typing import Any , Dict , List , NamedTuple , Optional , Tuple
4
4
5
5
from meta_memcache .base .base_serializer import BaseSerializer , EncodedValue
6
- from meta_memcache .protocol import Blob
6
+ from meta_memcache .protocol import Blob , Key
7
+ import zstandard as zstd
7
8
8
9
9
10
class MixedSerializer (BaseSerializer ):
@@ -13,13 +14,15 @@ class MixedSerializer(BaseSerializer):
13
14
LONG = 4
14
15
ZLIB_COMPRESSED = 8
15
16
BINARY = 16
17
+
16
18
COMPRESSION_THRESHOLD = 128
17
19
18
20
def __init__ (self , pickle_protocol : int = 0 ) -> None :
19
21
self ._pickle_protocol = pickle_protocol
20
22
21
23
def serialize (
22
24
self ,
25
+ key : Key ,
23
26
value : Any ,
24
27
) -> EncodedValue :
25
28
if isinstance (value , bytes ):
@@ -53,3 +56,175 @@ def unserialize(self, data: Blob, encoding_id: int) -> Any:
53
56
return bytes (data )
54
57
else :
55
58
return pickle .loads (data ) # noqa: S301
59
+
60
+
61
class DictionaryMapping(NamedTuple):
    """Pair a zstd dictionary with the key domains it should compress.

    ZstdSerializer registers a decompressor for every mapping it is given;
    a compressor is only registered when ``active_domains`` is non-empty.
    """

    # Raw bytes of the zstd dictionary.
    dictionary: bytes
    # Key domains whose values are compressed with this dictionary.
    active_domains: List[str]
64
+
65
+
66
class ZstdSerializer(BaseSerializer):
    """Serializer that compresses payloads with zstd, optionally per-domain.

    Values are first encoded (raw bytes, ASCII int, UTF-8 str or pickle) and
    then, when large enough, compressed. The compressor is chosen in this
    order: the dictionary mapped to ``key.domain``, the default zstd
    compressor (with or without a default dictionary), and finally zlib.

    zstd frames are written without the 4-byte magic number
    (``FORMAT_ZSTD1_MAGICLESS``); the magic is re-attached before decoding so
    the frame header, and with it the dictionary id, can be parsed.

    NOTE(review): dictionary id 0 is reserved for "no dictionary". Slot 0 of
    the decompressor table is always the plain decompressor, and a domain
    mapped to id 0 falls through to the default compressor. Dictionaries that
    report id 0 (presumably raw-content dictionaries) are therefore not
    usable as dictionaries here; confirm callers only pass trained ones.
    """

    # Encoding flags stored alongside the value. The *_COMPRESSED bits are
    # OR-ed onto the base encoding id.
    STR = 0
    PICKLE = 1
    INT = 2
    LONG = 4
    ZLIB_COMPRESSED = 8
    BINARY = 16
    ZSTD_COMPRESSED = 32

    # zstd frame magic number 0xFD2FB528 in little-endian byte order.
    ZSTD_MAGIC = b"\x28\xb5\x2f\xfd"
    DEFAULT_PICKLE_PROTOCOL = 5
    DEFAULT_COMPRESSION_LEVEL = 9
    DEFAULT_COMPRESSION_THRESHOLD = 128
    DEFAULT_DICT_COMPRESSION_THRESHOLD = 64

    _pickle_protocol: int
    _compression_level: int
    _default_compression_threshold: int
    _dict_compression_threshold: int
    _zstd_compressors: Dict[int, zstd.ZstdCompressor]
    _zstd_decompressors: Dict[int, zstd.ZstdDecompressor]
    _domain_to_dict_id: Dict[str, int]
    _default_zstd_compressor: Optional[zstd.ZstdCompressor]

    def __init__(
        self,
        pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL,
        compression_level: int = DEFAULT_COMPRESSION_LEVEL,
        compression_threshold: int = DEFAULT_COMPRESSION_THRESHOLD,
        dict_compression_threshold: int = DEFAULT_DICT_COMPRESSION_THRESHOLD,
        dictionary_mappings: Optional[List[DictionaryMapping]] = None,
        default_dictionary: Optional[bytes] = None,
        default_zstd: bool = True,
    ) -> None:
        """Build the compressor and decompressor tables.

        Args:
            pickle_protocol: Protocol passed to ``pickle.dumps``.
            compression_level: zstd level used for every zstd compressor.
            compression_threshold: Minimum payload size (bytes) compressed
                when no dictionary applies.
            dict_compression_threshold: Minimum payload size (bytes)
                compressed when a dictionary applies. It also becomes the
                default threshold when ``default_dictionary`` is given.
            dictionary_mappings: Dictionaries to register. Each one can
                always be used for decompression; it is used for compression
                only for the domains it lists as active.
            default_dictionary: Dictionary used for keys whose domain has no
                mapped dictionary.
            default_zstd: When no default dictionary is given, use plain zstd
                (True) or fall back to zlib (False) for unmapped domains.
        """
        self._pickle_protocol = pickle_protocol
        self._compression_level = compression_level
        # With a default dictionary every payload is dictionary-compressed,
        # so the lower dictionary threshold applies across the board.
        self._default_compression_threshold = (
            dict_compression_threshold if default_dictionary else compression_threshold
        )
        self._dict_compression_threshold = dict_compression_threshold
        self._zstd_compressors = {}
        self._zstd_decompressors = {}
        self._domain_to_dict_id = {}

        compression_params = zstd.ZstdCompressionParameters.from_level(
            compression_level,
            format=zstd.FORMAT_ZSTD1_MAGICLESS,
            write_content_size=True,
            write_checksum=False,
            write_dict_id=True,
        )

        for dictionary_mapping in dictionary_mappings or ():
            # _build_dict also registers the matching decompressor.
            dict_id, zstd_dict = self._build_dict(dictionary_mapping.dictionary)
            if dictionary_mapping.active_domains:
                # The dictionary is active for some domains
                self._add_dict_compressor(dict_id, zstd_dict, compression_params)
                for domain in dictionary_mapping.active_domains:
                    self._domain_to_dict_id[domain] = dict_id

        if default_dictionary:
            dict_id, zstd_dict = self._build_dict(default_dictionary)
            self._default_zstd_compressor = self._add_dict_compressor(
                dict_id, zstd_dict, compression_params
            )
        elif default_zstd:
            self._default_zstd_compressor = zstd.ZstdCompressor(
                compression_params=compression_params
            )
        else:
            self._default_zstd_compressor = None

        # Frames written without a dictionary carry dict id 0.
        self._zstd_decompressors[0] = zstd.ZstdDecompressor()

    def _build_dict(self, dictionary: bytes) -> Tuple[int, zstd.ZstdCompressionDict]:
        """Wrap raw dictionary bytes and register a decompressor for them.

        Returns:
            The dictionary id and the ``ZstdCompressionDict`` object.
        """
        zstd_dict = zstd.ZstdCompressionDict(dictionary)
        dict_id = zstd_dict.dict_id()
        self._add_dict_decompressor(dict_id, zstd_dict)
        return dict_id, zstd_dict

    def _add_dict_decompressor(
        self, dict_id: int, zstd_dict: zstd.ZstdCompressionDict
    ) -> zstd.ZstdDecompressor:
        """Register and return a decompressor bound to ``zstd_dict``."""
        decompressor = zstd.ZstdDecompressor(dict_data=zstd_dict)
        self._zstd_decompressors[dict_id] = decompressor
        return decompressor

    def _add_dict_compressor(
        self,
        dict_id: int,
        zstd_dict: zstd.ZstdCompressionDict,
        compression_params: zstd.ZstdCompressionParameters,
    ) -> zstd.ZstdCompressor:
        """Register and return a compressor bound to ``zstd_dict``."""
        compressor = zstd.ZstdCompressor(
            dict_data=zstd_dict, compression_params=compression_params
        )
        self._zstd_compressors[dict_id] = compressor
        return compressor

    def _compress(self, key: Key, data: bytes) -> Tuple[bytes, int]:
        """Compress ``data`` and return it with the matching encoding flag.

        Uses the dictionary mapped to ``key.domain`` if there is one, then
        the default zstd compressor, then zlib.
        """
        # A dict id of 0 is falsy and deliberately falls through: id 0 means
        # "no dictionary" in the decompressor table.
        if key.domain and (dict_id := self._domain_to_dict_id.get(key.domain)):
            return self._zstd_compressors[dict_id].compress(data), self.ZSTD_COMPRESSED
        elif self._default_zstd_compressor:
            return self._default_zstd_compressor.compress(data), self.ZSTD_COMPRESSED
        else:
            return zlib.compress(data), self.ZLIB_COMPRESSED

    def _decompress(self, data: bytes) -> bytes:
        """Decompress a magicless zstd frame.

        Raises:
            ValueError: If the frame names a dictionary id that was never
                registered with this serializer.
        """
        # Frames are stored without the magic number; restore it so the
        # header can be parsed and default-format decompressors accept it.
        data = self.ZSTD_MAGIC + data
        dict_id = zstd.get_frame_parameters(data).dict_id
        if decompressor := self._zstd_decompressors.get(dict_id):
            return decompressor.decompress(data)
        raise ValueError(f"Unknown dictionary id: {dict_id} ")

    def _should_compress(self, key: Key, data: bytes) -> bool:
        """Return True if ``data`` is large enough to be worth compressing.

        Payloads at or above the default threshold are always compressed.
        Payloads between the dictionary threshold and the default threshold
        are compressed only when ``key.domain`` has a mapped dictionary.
        """
        data_len = len(data)
        if data_len >= self._default_compression_threshold:
            return True
        elif data_len >= self._dict_compression_threshold:
            return bool(key.domain and self._domain_to_dict_id.get(key.domain))
        return False

    def serialize(
        self,
        key: Key,
        value: Any,
    ) -> EncodedValue:
        """Encode ``value`` and compress it when appropriate.

        bytes are stored as-is, ints (but not bools) as ASCII digits, str as
        encoded text, and everything else is pickled. Compression is skipped
        when ``key.disable_compression`` is set or the payload is too small.
        """
        if isinstance(value, bytes):
            data = value
            encoding_id = self.BINARY
        elif isinstance(value, int) and not isinstance(value, bool):
            data = str(value).encode("ascii")
            encoding_id = self.INT
        elif isinstance(value, str):
            data = str(value).encode()
            encoding_id = self.STR
        else:
            data = pickle.dumps(value, protocol=self._pickle_protocol)
            encoding_id = self.PICKLE

        if not key.disable_compression and self._should_compress(key, data):
            data, compression_flag = self._compress(key, data)
            encoding_id |= compression_flag
        return EncodedValue(data=data, encoding_id=encoding_id)

    def unserialize(self, data: Blob, encoding_id: int) -> Any:
        """Reverse ``serialize``: decompress if flagged, then decode.

        Raises:
            ValueError: If a zstd frame references an unknown dictionary id.
        """
        # Strip the compression bit so the base encoding can be compared.
        if encoding_id & self.ZLIB_COMPRESSED:
            data = zlib.decompress(data)
            encoding_id ^= self.ZLIB_COMPRESSED
        elif encoding_id & self.ZSTD_COMPRESSED:
            data = self._decompress(data)
            encoding_id ^= self.ZSTD_COMPRESSED

        if encoding_id == self.STR:
            return bytes(data).decode()
        elif encoding_id in (self.INT, self.LONG):
            return int(data)
        elif encoding_id == self.BINARY:
            return bytes(data)
        else:
            # SECURITY: unpickling executes arbitrary code. This is only safe
            # when everything that can write to the cache is trusted.
            return pickle.loads(data)  # noqa: S301
0 commit comments