Skip to content

Commit ef22e7e

Browse files
committed
Add ZstdSerializer, change Key to include "domain"
1 parent 87661ea commit ef22e7e

11 files changed

+703
-108
lines changed

.flake8

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ application-import-names = meta_memcache,tests
77
import-order-style = cryptography
88
per-file-ignores =
99
__init__.py:F401
10-
tests/*:S101,S403
10+
tests/*:S101,S403

README.md

+16-4
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,26 @@ will be gone or present, since they are stored in the same server). Note this is
9797
also risky, if you place all keys of a user in the same server, and the server
9898
goes down, the user life will be miserable.
9999

100-
### Unicode keys:
101-
Unicode keys are supported, the keys will be hashed according to Meta commands
100+
### Custom domains:
101+
You can add a domain to keys. This domain can be used for custom per-domain
102+
metrics like hit ratios or to control serialization of the values.
103+
```python:
104+
Key("key:1:2", domain="example")
105+
```
106+
For example the ZstdSerializer allows to configure different dictionaries by
107+
domain, so you can compress more efficiently data of different domains.
108+
109+
### Unicode/binary keys:
110+
Both unicode and binary keys are supported, the keys will be hashed/encoded according to Meta commands
102111
[binary encoded keys](https://github.com/memcached/memcached/wiki/MetaCommands#binary-encoded-keys)
103112
specification.
104113

105-
To use this, mark the key as unicode:
114+
Using binary keys can have benefits, saving space in memory. While over the wire the key
115+
is transmited b64 encoded, the memcache server will use the byte representation, so it will
116+
not have the 1/4 overhead of b64 encoding.
117+
106118
```python:
107-
Key("🍺", unicode=True)
119+
Key("🍺")
108120
```
109121

110122
### Large keys:

poetry.lock

+184-50
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ packages = [{include = "meta_memcache", from="src"}]
1414
python = "^3.8"
1515
uhashring = "^2.1"
1616
marisa-trie = "^1.0.0"
17-
meta-memcache-socket = "0.1.1"
17+
meta-memcache-socket = "0.1.3"
18+
zstandard = "^0.22.0"
1819

1920
[tool.poetry.group.extras.dependencies]
2021
prometheus-client = "^0.17.1"
@@ -27,7 +28,7 @@ testpaths = [
2728

2829
[tool.isort]
2930
profile = "black"
30-
known_third_party = ["uhashring", "pytest", "pytest_mock", "marisa-trie"]
31+
known_third_party = ["uhashring", "pytest", "pytest_mock", "marisa-trie", "zstandard"]
3132

3233
[tool.coverage.paths]
3334
source = ["src", "*/site-packages"]

src/meta_memcache/base/base_serializer.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from abc import ABC, abstractmethod
22
from typing import Any, NamedTuple
33

4-
from meta_memcache.protocol import Blob
4+
from meta_memcache.protocol import Blob, Key
55

66

77
class EncodedValue(NamedTuple):
@@ -13,6 +13,7 @@ class BaseSerializer(ABC):
1313
@abstractmethod
1414
def serialize(
1515
self,
16+
key: Key,
1617
value: Any,
1718
) -> EncodedValue: ...
1819

src/meta_memcache/executors/default.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,11 @@ def _build_cmd(
6868

6969
def _prepare_serialized_value_and_flags(
7070
self,
71+
key: Key,
7172
value: ValueContainer,
7273
flags: Optional[RequestFlags],
7374
) -> Tuple[Optional[bytes], RequestFlags]:
74-
encoded_value = self._serializer.serialize(value.value)
75+
encoded_value = self._serializer.serialize(key, value.value)
7576
flags = flags if flags is not None else RequestFlags()
7677
flags.client_flag = encoded_value.encoding_id
7778
return encoded_value.data, flags
@@ -106,7 +107,7 @@ def exec_on_pool(
106107
cmd_value, flags = (
107108
(None, flags)
108109
if value is None
109-
else self._prepare_serialized_value_and_flags(value, flags)
110+
else self._prepare_serialized_value_and_flags(key, value, flags)
110111
)
111112
try:
112113
conn = pool.pop_connection()
@@ -159,7 +160,7 @@ def exec_multi_on_pool( # noqa: C901
159160
cmd_value, flags = (
160161
(None, flags)
161162
if value is None
162-
else self._prepare_serialized_value_and_flags(value, flags)
163+
else self._prepare_serialized_value_and_flags(key, value, flags)
163164
)
164165

165166
self._conn_send_cmd(

src/meta_memcache/protocol.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,23 @@
2222

2323
@dataclass
2424
class Key:
25-
__slots__ = ("key", "routing_key", "is_unicode")
25+
__slots__ = ("key", "routing_key", "domain", "disable_compression")
2626
key: str
2727
routing_key: Optional[str]
28-
is_unicode: bool
28+
domain: Optional[str]
29+
disable_compression: bool
2930

3031
def __init__(
3132
self,
3233
key: str,
3334
routing_key: Optional[str] = None,
34-
is_unicode: bool = False,
35+
domain: Optional[str] = None,
36+
disabled_compression: bool = False,
3537
) -> None:
3638
self.key = key
3739
self.routing_key = routing_key
38-
self.is_unicode = is_unicode
40+
self.domain = domain
41+
self.disable_compression = disabled_compression
3942

4043
def __hash__(self) -> int:
4144
return hash((self.key, self.routing_key))

src/meta_memcache/serializer.py

+177-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import pickle # noqa: S403
22
import zlib
3-
from typing import Any
3+
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
44

55
from meta_memcache.base.base_serializer import BaseSerializer, EncodedValue
6-
from meta_memcache.protocol import Blob
6+
from meta_memcache.protocol import Blob, Key
7+
import zstandard as zstd
78

89

910
class MixedSerializer(BaseSerializer):
@@ -13,13 +14,15 @@ class MixedSerializer(BaseSerializer):
1314
LONG = 4
1415
ZLIB_COMPRESSED = 8
1516
BINARY = 16
17+
1618
COMPRESSION_THRESHOLD = 128
1719

1820
def __init__(self, pickle_protocol: int = 0) -> None:
1921
self._pickle_protocol = pickle_protocol
2022

2123
def serialize(
2224
self,
25+
key: Key,
2326
value: Any,
2427
) -> EncodedValue:
2528
if isinstance(value, bytes):
@@ -53,3 +56,175 @@ def unserialize(self, data: Blob, encoding_id: int) -> Any:
5356
return bytes(data)
5457
else:
5558
return pickle.loads(data) # noqa: S301
59+
60+
61+
class DictionaryMapping(NamedTuple):
62+
dictionary: bytes
63+
active_domains: List[str]
64+
65+
66+
class ZstdSerializer(BaseSerializer):
67+
STR = 0
68+
PICKLE = 1
69+
INT = 2
70+
LONG = 4
71+
ZLIB_COMPRESSED = 8
72+
BINARY = 16
73+
ZSTD_COMPRESSED = 32
74+
75+
ZSTD_MAGIC = b"(\xb5/\xfd"
76+
DEFAULT_PICKLE_PROTOCOL = 5
77+
DEFAULT_COMPRESSION_LEVEL = 9
78+
DEFAULT_COMPRESSION_THRESHOLD = 128
79+
DEFAULT_DICT_COMPRESSION_THRESHOLD = 64
80+
81+
_pickle_protocol: int
82+
_compression_level: int
83+
_default_compression_threshold: int
84+
_dict_compression_threshold: int
85+
_zstd_compressors: Dict[int, zstd.ZstdCompressor]
86+
_zstd_decompressors: Dict[int, zstd.ZstdDecompressor]
87+
_domain_to_dict_id: Dict[str, int]
88+
_default_zstd_compressor: Optional[zstd.ZstdCompressor]
89+
90+
def __init__(
91+
self,
92+
pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL,
93+
compression_level: int = DEFAULT_COMPRESSION_LEVEL,
94+
compression_threshold: int = DEFAULT_COMPRESSION_THRESHOLD,
95+
dict_compression_threshold: int = DEFAULT_DICT_COMPRESSION_THRESHOLD,
96+
dictionary_mappings: Optional[List[DictionaryMapping]] = None,
97+
default_dictionary: Optional[bytes] = None,
98+
default_zstd: bool = True,
99+
) -> None:
100+
self._pickle_protocol = pickle_protocol
101+
self._compression_level = compression_level
102+
self._default_compression_threshold = (
103+
compression_threshold
104+
if not default_dictionary
105+
else dict_compression_threshold
106+
)
107+
self._dict_compression_threshold = dict_compression_threshold
108+
self._zstd_compressors = {}
109+
self._zstd_decompressors = {}
110+
self._domain_to_dict_id = {}
111+
112+
compression_params = zstd.ZstdCompressionParameters.from_level(
113+
compression_level,
114+
format=zstd.FORMAT_ZSTD1_MAGICLESS,
115+
write_content_size=True,
116+
write_checksum=False,
117+
write_dict_id=True,
118+
)
119+
120+
if dictionary_mappings:
121+
for dictionary_mapping in dictionary_mappings:
122+
dict_id, zstd_dict = self._build_dict(dictionary_mapping.dictionary)
123+
self._add_dict_decompressor(dict_id, zstd_dict)
124+
if dictionary_mapping.active_domains:
125+
# The dictionary is active for some domains
126+
self._add_dict_compressor(dict_id, zstd_dict, compression_params)
127+
for domain in dictionary_mapping.active_domains:
128+
self._domain_to_dict_id[domain] = dict_id
129+
130+
if default_dictionary:
131+
dict_id, zstd_dict = self._build_dict(default_dictionary)
132+
self._add_dict_decompressor(dict_id, zstd_dict)
133+
134+
self._default_zstd_compressor = self._add_dict_compressor(
135+
dict_id, zstd_dict, compression_params
136+
)
137+
elif default_zstd:
138+
self._default_zstd_compressor = zstd.ZstdCompressor(
139+
compression_params=compression_params
140+
)
141+
else:
142+
self._default_zstd_compressor = None
143+
144+
self._zstd_decompressors[0] = zstd.ZstdDecompressor()
145+
146+
def _build_dict(self, dictionary: bytes) -> Tuple[int, zstd.ZstdCompressionDict]:
147+
zstd_dict = zstd.ZstdCompressionDict(dictionary)
148+
dict_id = zstd_dict.dict_id()
149+
self._zstd_decompressors[dict_id] = zstd.ZstdDecompressor(dict_data=zstd_dict)
150+
return dict_id, zstd_dict
151+
152+
def _add_dict_decompressor(
153+
self, dict_id: int, zstd_dict: zstd.ZstdCompressionDict
154+
) -> zstd.ZstdDecompressor:
155+
self._zstd_decompressors[dict_id] = zstd.ZstdDecompressor(dict_data=zstd_dict)
156+
return self._zstd_decompressors[dict_id]
157+
158+
def _add_dict_compressor(
159+
self,
160+
dict_id: int,
161+
zstd_dict: zstd.ZstdCompressionDict,
162+
compression_params: zstd.ZstdCompressionParameters,
163+
) -> zstd.ZstdCompressor:
164+
self._zstd_compressors[dict_id] = zstd.ZstdCompressor(
165+
dict_data=zstd_dict, compression_params=compression_params
166+
)
167+
return self._zstd_compressors[dict_id]
168+
169+
def _compress(self, key: Key, data: bytes) -> Tuple[bytes, int]:
170+
if key.domain and (dict_id := self._domain_to_dict_id.get(key.domain)):
171+
return self._zstd_compressors[dict_id].compress(data), self.ZSTD_COMPRESSED
172+
elif self._default_zstd_compressor:
173+
return self._default_zstd_compressor.compress(data), self.ZSTD_COMPRESSED
174+
else:
175+
return zlib.compress(data), self.ZLIB_COMPRESSED
176+
177+
def _decompress(self, data: bytes) -> bytes:
178+
data = self.ZSTD_MAGIC + data
179+
dict_id = zstd.get_frame_parameters(data).dict_id
180+
if decompressor := self._zstd_decompressors.get(dict_id):
181+
return decompressor.decompress(data)
182+
raise ValueError(f"Unknown dictionary id: {dict_id}")
183+
184+
def _should_compress(self, key: Key, data: bytes) -> bool:
185+
data_len = len(data)
186+
if data_len >= self._default_compression_threshold:
187+
return True
188+
elif data_len >= self._dict_compression_threshold:
189+
return bool(key.domain and self._domain_to_dict_id.get(key.domain))
190+
return False
191+
192+
def serialize(
193+
self,
194+
key: Key,
195+
value: Any,
196+
) -> EncodedValue:
197+
if isinstance(value, bytes):
198+
data = value
199+
encoding_id = self.BINARY
200+
elif isinstance(value, int) and not isinstance(value, bool):
201+
data = str(value).encode("ascii")
202+
encoding_id = self.INT
203+
elif isinstance(value, str):
204+
data = str(value).encode()
205+
encoding_id = self.STR
206+
else:
207+
data = pickle.dumps(value, protocol=self._pickle_protocol)
208+
encoding_id = self.PICKLE
209+
210+
if not key.disable_compression and self._should_compress(key, data):
211+
data, compression_flag = self._compress(key, data)
212+
encoding_id |= compression_flag
213+
return EncodedValue(data=data, encoding_id=encoding_id)
214+
215+
def unserialize(self, data: Blob, encoding_id: int) -> Any:
216+
if encoding_id & self.ZLIB_COMPRESSED:
217+
data = zlib.decompress(data)
218+
encoding_id ^= self.ZLIB_COMPRESSED
219+
elif encoding_id & self.ZSTD_COMPRESSED:
220+
data = self._decompress(data)
221+
encoding_id ^= self.ZSTD_COMPRESSED
222+
223+
if encoding_id == self.STR:
224+
return bytes(data).decode()
225+
elif encoding_id in (self.INT, self.LONG):
226+
return int(data)
227+
elif encoding_id == self.BINARY:
228+
return bytes(data)
229+
else:
230+
return pickle.loads(data) # noqa: S301

0 commit comments

Comments
 (0)