Skip to content

Commit 3c8465f

Browse files
committed
Allow keys to be encoded before use.
Ported patch in #52 from @harlowja to current branch. Added tests. For the cases where the user wants to transparently encode keys (say using urllib) before they are used further allow a encoding function to be passed in that will perform these types of activities (by default it is the identity function).
1 parent 88b83c6 commit 3c8465f

File tree

2 files changed

+54
-18
lines changed

2 files changed

+54
-18
lines changed

memcache.py

+27-18
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,8 @@ def __init__(self, servers, debug=0, pickleProtocol=0,
162162
pload=None, pid=None,
163163
server_max_key_length=None, server_max_value_length=None,
164164
dead_retry=_DEAD_RETRY, socket_timeout=_SOCKET_TIMEOUT,
165-
cache_cas=False, flush_on_reconnect=0, check_keys=True):
165+
cache_cas=False, flush_on_reconnect=0, check_keys=True,
166+
key_encoder=None):
166167
"""Create a new Client object with the given list of servers.
167168
168169
@param servers: C{servers} is passed to L{set_servers}.
@@ -205,6 +206,10 @@ def __init__(self, servers, debug=0, pickleProtocol=0,
205206
@param check_keys: (default True) If True, the key is checked
206207
to ensure it is the correct length and composed of the right
207208
characters.
209+
@param key_encoder: (default None) If provided a functor that will
210+
be called to encode keys before they are checked and used. It will
211+
be expected to take one parameter (the key) and return a new encoded
212+
key as a result.
208213
"""
209214
super(Client, self).__init__()
210215
self.debug = debug
@@ -226,6 +231,10 @@ def __init__(self, servers, debug=0, pickleProtocol=0,
226231
self.persistent_load = pload
227232
self.persistent_id = pid
228233
self.server_max_key_length = server_max_key_length
234+
if key_encoder is None:
235+
def key_encoder(key):
236+
return key
237+
self.key_encoder = key_encoder
229238
if self.server_max_key_length is None:
230239
self.server_max_key_length = SERVER_MAX_KEY_LENGTH
231240
self.server_max_value_length = server_max_value_length
@@ -494,7 +503,7 @@ def delete_multi(self, keys, time=None, key_prefix='', noreply=False):
494503
else:
495504
headers = None
496505
for key in server_keys[server]: # These are mangled keys
497-
cmd = self._encode_cmd('delete', key, headers, noreply, b'\r\n')
506+
cmd = self._encode_cmd('delete', self.key_encoder(key), headers, noreply, b'\r\n')
498507
write(cmd)
499508
try:
500509
server.send_cmds(b''.join(bigcmd))
@@ -532,7 +541,7 @@ def delete(self, key, noreply=False):
532541
reply.
533542
@rtype: int
534543
'''
535-
key = self._encode_key(key)
544+
key = self._encode_key(self.key_encoder(key))
536545
if self.do_check_key:
537546
self.check_key(key)
538547
server, key = self._get_server(key)
@@ -568,7 +577,7 @@ def touch(self, key, time=0, noreply=False):
568577
reply.
569578
@rtype: int
570579
'''
571-
key = self._encode_key(key)
580+
key = self._encode_key(self.key_encoder(key))
572581
if self.do_check_key:
573582
self.check_key(key)
574583
server, key = self._get_server(key)
@@ -622,7 +631,7 @@ def incr(self, key, delta=1, noreply=False):
622631
@return: New value after incrementing, no None for noreply or error.
623632
@rtype: int
624633
"""
625-
return self._incrdecr("incr", key, delta, noreply)
634+
return self._incrdecr("incr", self.key_encoder(key), delta, noreply)
626635

627636
def decr(self, key, delta=1, noreply=False):
628637
"""Decrement value for C{key} by C{delta}
@@ -640,7 +649,7 @@ def decr(self, key, delta=1, noreply=False):
640649
@return: New value after decrementing, or None for noreply or error.
641650
@rtype: int
642651
"""
643-
return self._incrdecr("decr", key, delta, noreply)
652+
return self._incrdecr("decr", self.key_encoder(key), delta, noreply)
644653

645654
def _incrdecr(self, cmd, key, delta, noreply=False):
646655
key = self._encode_key(key)
@@ -674,7 +683,7 @@ def add(self, key, val, time=0, min_compress_len=0, noreply=False):
674683
@return: Nonzero on success.
675684
@rtype: int
676685
'''
677-
return self._set("add", key, val, time, min_compress_len, noreply)
686+
return self._set("add", self.key_encoder(key), val, time, min_compress_len, noreply)
678687

679688
def append(self, key, val, time=0, min_compress_len=0, noreply=False):
680689
'''Append the value to the end of the existing key's value.
@@ -685,7 +694,7 @@ def append(self, key, val, time=0, min_compress_len=0, noreply=False):
685694
@return: Nonzero on success.
686695
@rtype: int
687696
'''
688-
return self._set("append", key, val, time, min_compress_len, noreply)
697+
return self._set("append", self.key_encoder(key), val, time, min_compress_len, noreply)
689698

690699
def prepend(self, key, val, time=0, min_compress_len=0, noreply=False):
691700
'''Prepend the value to the beginning of the existing key's value.
@@ -696,7 +705,7 @@ def prepend(self, key, val, time=0, min_compress_len=0, noreply=False):
696705
@return: Nonzero on success.
697706
@rtype: int
698707
'''
699-
return self._set("prepend", key, val, time, min_compress_len, noreply)
708+
return self._set("prepend", self.key_encoder(key), val, time, min_compress_len, noreply)
700709

701710
def replace(self, key, val, time=0, min_compress_len=0, noreply=False):
702711
'''Replace existing key with value.
@@ -707,7 +716,7 @@ def replace(self, key, val, time=0, min_compress_len=0, noreply=False):
707716
@return: Nonzero on success.
708717
@rtype: int
709718
'''
710-
return self._set("replace", key, val, time, min_compress_len, noreply)
719+
return self._set("replace", self.key_encoder(key), val, time, min_compress_len, noreply)
711720

712721
def set(self, key, val, time=0, min_compress_len=0, noreply=False):
713722
'''Unconditionally sets a key to a given value in the memcache.
@@ -743,7 +752,7 @@ def set(self, key, val, time=0, min_compress_len=0, noreply=False):
743752
'''
744753
if isinstance(time, timedelta):
745754
time = int(time.total_seconds())
746-
return self._set("set", key, val, time, min_compress_len, noreply)
755+
return self._set("set", self.key_encoder(key), val, time, min_compress_len, noreply)
747756

748757
def cas(self, key, val, time=0, min_compress_len=0, noreply=False):
749758
'''Check and set (CAS)
@@ -780,7 +789,7 @@ def cas(self, key, val, time=0, min_compress_len=0, noreply=False):
780789
@param noreply: optional parameter instructs the server to not
781790
send the reply.
782791
'''
783-
return self._set("cas", key, val, time, min_compress_len, noreply)
792+
return self._set("cas", self.key_encoder(key), val, time, min_compress_len, noreply)
784793

785794
def _map_and_prefix_keys(self, key_iterable, key_prefix):
786795
"""Map keys to the servers they will reside on.
@@ -807,7 +816,7 @@ def _map_and_prefix_keys(self, key_iterable, key_prefix):
807816
# Ensure call to _get_server gets a Tuple as well.
808817
serverhash, key = orig_key
809818

810-
key = self._encode_key(key)
819+
key = self._encode_key(self.key_encoder(key))
811820
if not isinstance(key, six.binary_type):
812821
# set_multi supports int / long keys.
813822
key = str(key).encode('utf8')
@@ -818,7 +827,7 @@ def _map_and_prefix_keys(self, key_iterable, key_prefix):
818827
server, key = self._get_server(
819828
(serverhash, key_prefix + key))
820829
else:
821-
key = self._encode_key(orig_key)
830+
key = self._encode_key(self.key_encoder(orig_key))
822831
if not isinstance(key, six.binary_type):
823832
# set_multi supports int / long keys.
824833
key = str(key).encode('utf8')
@@ -923,7 +932,7 @@ def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0,
923932
if store_info:
924933
flags, len_val, val = store_info
925934
headers = "%d %d %d" % (flags, time, len_val)
926-
fullcmd = self._encode_cmd('set', key, headers,
935+
fullcmd = self._encode_cmd('set', self.key_encoder(key), headers,
927936
noreply,
928937
b'\r\n', val, b'\r\n')
929938
write(fullcmd)
@@ -1121,14 +1130,14 @@ def get(self, key, default=None):
11211130
11221131
@return: The value or None.
11231132
'''
1124-
return self._get('get', key, default)
1133+
return self._get('get', self.key_encoder(key), default)
11251134

11261135
def gets(self, key):
11271136
'''Retrieves a key from the memcache. Used in conjunction with 'cas'.
11281137
11291138
@return: The value or None.
11301139
'''
1131-
return self._get('gets', key)
1140+
return self._get('gets', self.key_encoder(key))
11321141

11331142
def get_multi(self, keys, key_prefix=''):
11341143
'''Retrieves multiple keys from the memcache doing just one query.
@@ -1188,7 +1197,7 @@ def get_multi(self, keys, key_prefix=''):
11881197
self._statlog('get_multi')
11891198

11901199
server_keys, prefixed_to_orig_key = self._map_and_prefix_keys(
1191-
keys, key_prefix)
1200+
[self.key_encoder(k) for k in keys], key_prefix)
11921201

11931202
# send out all requests on each server before reading anything
11941203
dead_servers = []

tests/test_memcache.py

+27
Original file line numberDiff line numberDiff line change
@@ -252,5 +252,32 @@ def test_touch_unexpected_reply(self, mock_readline, mock_send_cmd):
252252
)
253253

254254

255+
class TestMemcacheEncoder(unittest.TestCase):
256+
def setUp(self):
257+
# TODO(): unix socket server stuff
258+
servers = ["127.0.0.1:11211"]
259+
self.mc = Client(servers, debug=1, key_encoder=self.encoder)
260+
261+
def tearDown(self):
262+
self.mc.flush_all()
263+
self.mc.disconnect_all()
264+
265+
def encoder(self, key):
266+
return key.lower()
267+
268+
def check_setget(self, key, val, noreply=False):
269+
self.mc.set(key, val, noreply=noreply)
270+
newval = self.mc.get(key)
271+
self.assertEqual(newval, val)
272+
273+
def test_setget(self):
274+
self.check_setget("a_string", "some random string")
275+
self.check_setget("A_String2", "some random string")
276+
self.check_setget("an_integer", 42)
277+
self.assertEqual("some random string", self.mc.get("A_String"))
278+
self.assertEqual("some random string", self.mc.get("a_sTRing2"))
279+
self.assertEqual(42, self.mc.get("An_Integer"))
280+
281+
255282
if __name__ == '__main__':
256283
unittest.main()

0 commit comments

Comments
 (0)