11import unittest
22
33import redis
4- from hamcrest import assert_that , equal_to , raises , is_ , instance_of , empty , is_not , greater_than
4+ from fakeredis import FakeStrictRedis
5+ from hamcrest import assert_that , equal_to , raises , is_ , instance_of , greater_than , empty , is_not
56from mock import mock
6- from mockredis import mock_strict_redis_client
77from redis import StrictRedis
8+ from redis .exceptions import LockError
9+ from redis .lock import LuaLock
810
911from pybloom .src .backends .bitarraybackend import BitArrayBackend
1012from pybloom .src .backends .numpybackend import NumpyBackend
1113from pybloom .src .backends .redisbackend import RedisBackend , RedisProxy
1214from pybloom .src .bloomfilter import BloomFilter , BloomFilterException , Options , Size , size_to_human_format
1315
16+ LUA_ADD_SCRIPT = """
17+
18+ -- https://gist.github.com/tylerneylon/59f4bcf316be525b30ab
19+ local json = {}
20+
-- Classify a value for JSON encoding.
-- Returns type(obj) for anything that is not a table. For tables, returns
-- 'array' when every key visited by pairs() can be matched by consecutive
-- integer indexes starting at 1, and 'table' otherwise (including when the
-- table is empty).
local function kind_of(obj)
  local obj_type = type(obj)
  if obj_type ~= 'table' then
    return obj_type
  end

  -- Walk once per stored key; each step requires the next integer slot to exist.
  local next_index = 1
  for _ in pairs(obj) do
    if obj[next_index] == nil then
      return 'table'
    end
    next_index = next_index + 1
  end

  -- next_index never advanced: the table is empty and is treated as an object.
  if next_index > 1 then
    return 'array'
  end
  return 'table'
end
29+
-- Escape a string for embedding in JSON output.
-- For each entry of in_char, every match in s is replaced (via gsub) by the
-- prefix literal concatenated with the out_char entry at the same index.
-- Returns the rewritten string.
-- NOTE(review): the quoted escape literals below appear space-separated in
-- this rendering of the file; confirm them against the upstream gist before
-- relying on the exact characters.
30+ local function escape_str(s)
31+ local in_char = { '\\ \\ ', '\\ "', '/', '\b ', '\f ', '\\ \n ', '\\ \r ', '\t ' }
32+ local out_char = {'\\ \\ ', '\\ "', '/', 'b', 'f', 'n', 'r', 't'}
33+
-- Each in_char entry is passed to gsub as a Lua pattern, not as plain text.
34+ for i, c in ipairs(in_char) do
-- gsub returns (string, count); the assignment keeps only the string.
35+ s = s:gsub(c, '\\ \\ ' .. out_char[i])
36+ end
37+ return s
38+ end
39+
40+
-- Skip whitespace starting at pos and then look for a single delimiter.
-- Returns (position after the delimiter, true) when delim is found, or
-- (position of the first non-whitespace character, false) when it is not.
-- When err_if_missing is truthy, a missing delimiter raises an error instead.
local function skip_delim(str, pos, delim, err_if_missing)
  local leading_ws = str:match('^%s*', pos)
  local cursor = pos + #leading_ws

  if str:sub(cursor, cursor) == delim then
    return cursor + 1, true
  end

  if err_if_missing then
    error('Expected ' .. delim .. ' near position ' .. cursor)
  end
  return cursor, false
end
51+
-- Parse the body of a JSON string, one character per recursive call.
-- str: full input; pos: index of the next character to read (the caller has
-- already consumed the opening quote); val: text accumulated so far
-- (defaults to the empty string).
-- Returns the accumulated text and the position just past the terminator.
-- Raises an error when pos runs past the end of str.
-- NOTE(review): the quote/escape literals below appear space-separated in
-- this rendering of the file; confirm them against the upstream gist.
52+ local function parse_str_val(str, pos, val)
53+ val = val or ''
54+ local early_end_error = 'End of input found while parsing string.'
55+ if pos > #str then error(early_end_error) end
56+ local c = str:sub(pos, pos)
-- Terminator reached: hand back what was collected.
57+ if c == '\\ "' then return val, pos + 1 end
-- Ordinary character: append it and continue with a tail call.
58+ if c ~= '\\ \\ ' then return parse_str_val(str, pos + 1, val .. c) end
59+ -- We must have a \ character.
-- Map the character after the escape marker to its replacement; characters
-- missing from esc_map are kept as-is.
60+ local esc_map = {b = '\b ', f = '\f ', n = '\\ \n ', r = '\\ \r ', t = '\t '}
61+ local nextc = str:sub(pos + 1, pos + 1)
-- NOTE(review): str:sub returns a string (possibly empty), never nil, so this
-- check cannot fire; truncated input is instead reported by the pos > #str
-- check at the top of the next call.
62+ if not nextc then error(early_end_error) end
63+ return parse_str_val(str, pos + 2, val .. (esc_map[nextc] or nextc))
64+ end
65+
-- Parse a JSON number that starts at pos.
-- Returns the numeric value and the position just past the matched text.
-- Raises an error when the text at pos does not convert to a number.
local function parse_num_val(str, pos)
  -- Optional sign, digits, optional fraction, optional exponent.
  local number_pattern = '^-?%d+%.?%d*[eE]?[+-]?%d*'
  local literal = str:match(number_pattern, pos)
  local parsed = tonumber(literal)
  if parsed == nil then
    error('Error parsing number at position ' .. pos .. '.')
  end
  return parsed, pos + #literal
end
72+
-- Serialize a Lua value to a JSON string.
-- obj: the value to encode. as_key: truthy when obj is being encoded as an
-- object key; numbers are then wrapped in double quotes, and arrays/tables
-- raise an error.
-- Returns the encoded string. Raises an error for any kind other than
-- array, table, string, number, boolean or nil.
-- Object members are emitted in pairs() order, which Lua leaves unspecified.
73+ function json.stringify(obj, as_key)
74+ local s = {} -- We'll build the string as an array of strings to be concatenated.
75+ local kind = kind_of(obj) -- This is 'array' if it's an array or type(obj) otherwise.
76+ if kind == 'array' then
77+ if as_key then error('Can\\ \' t encode array as key.') end
78+ s[#s + 1] = '['
79+ for i, val in ipairs(obj) do
80+ if i > 1 then s[#s + 1] = ', ' end
81+ s[#s + 1] = json.stringify(val)
82+ end
83+ s[#s + 1] = ']'
84+ elseif kind == 'table' then
85+ if as_key then error('Can\\ \' t encode table as key.') end
86+ s[#s + 1] = '{'
87+ for k, v in pairs(obj) do
-- s already holds the opening brace, so more than one piece means a previous
-- member was written and a separator is needed.
88+ if #s > 1 then s[#s + 1] = ', ' end
89+ s[#s + 1] = json.stringify(k, true)
90+ s[#s + 1] = ':'
91+ s[#s + 1] = json.stringify(v)
92+ end
93+ s[#s + 1] = '}'
94+ elseif kind == 'string' then
95+ return '"' .. escape_str(obj) .. '"'
96+ elseif kind == 'number' then
97+ if as_key then return '"' .. tostring(obj) .. '"' end
98+ return tostring(obj)
99+ elseif kind == 'boolean' then
100+ return tostring(obj)
101+ elseif kind == 'nil' then
102+ return 'null'
103+ else
104+ error('Unjsonifiable type: ' .. kind .. '.')
105+ end
-- Only the array and table branches fall through to here.
106+ return table.concat(s)
107+ end
108+
-- Sentinel table returned by json.parse for the JSON literal null, so that a
-- parsed null can be told apart from the nil used to signal an end delimiter.
109+ json.null = {}
110+
-- Parse one JSON value from str.
-- pos: index to start at (defaults to 1). end_delim: when given, a character
-- that marks the end of the enclosing object or array.
-- Returns the parsed value and the position just past it. When end_delim is
-- the first non-whitespace character, returns nil and the position after it.
-- Raises an error on premature end of input, on a missing comma or colon,
-- and on text that matches no JSON value.
111+ function json.parse(str, pos, end_delim)
112+ pos = pos or 1
113+ if pos > #str then error('Reached unexpected end of input.') end
114+ local pos = pos + #str:match('^%s*', pos) -- Skip whitespace.
115+ local first = str:sub(pos, pos)
116+ if first == '{' then -- Parse an object.
117+ local obj, key, delim_found = {}, true, true
118+ pos = pos + 1
119+ while true do
-- A nil key means the closing brace was consumed by the recursive call.
120+ key, pos = json.parse(str, pos, '}')
121+ if key == nil then return obj, pos end
-- delim_found reflects whether a comma followed the previous member.
122+ if not delim_found then error('Comma missing between object items.') end
123+ pos = skip_delim(str, pos, ':', true) -- true -> error if missing.
124+ obj[key], pos = json.parse(str, pos)
125+ pos, delim_found = skip_delim(str, pos, ',')
126+ end
127+ elseif first == '[' then -- Parse an array.
128+ local arr, val, delim_found = {}, true, true
129+ pos = pos + 1
130+ while true do
-- A nil val means the closing bracket was consumed by the recursive call.
131+ val, pos = json.parse(str, pos, ']')
132+ if val == nil then return arr, pos end
133+ if not delim_found then error('Comma missing between array items.') end
134+ arr[#arr + 1] = val
135+ pos, delim_found = skip_delim(str, pos, ',')
136+ end
137+ elseif first == '"' then -- Parse a string.
138+ return parse_str_val(str, pos + 1)
139+ elseif first == '-' or first:match('%d') then -- Parse a number.
140+ return parse_num_val(str, pos)
141+ elseif first == end_delim then -- End of an object or array.
142+ return nil, pos + 1
143+ else -- Parse true, false, or null.
144+ local literals = {['true'] = true, ['false'] = false, ['null'] = json.null}
-- Compare the text at pos against each literal name in turn.
145+ for lit_str, lit_val in pairs(literals) do
146+ local lit_end = pos + #lit_str - 1
147+ if str:sub(pos, lit_end) == lit_str then return lit_val, lit_end + 1 end
148+ end
149+ local pos_info_str = 'position ' .. pos .. ': ' .. str:sub(pos, pos + 10)
150+ error('Invalid json syntax starting at ' .. pos_info_str)
151+ end
152+ end
153+
-- Entry point of the add script.
-- KEYS[1] is the metadata hash holding the 'capacity', 'filter_size' and
-- 'hash_size' fields. Each ARGV entry is a JSON object with a 'key' and an
-- 'offset' naming one bit to set.
-- Returns false when the metadata is missing or the filter is full;
-- otherwise returns the (possibly incremented) capacity.
local metadata_key = KEYS[1]
local capacity = tonumber(redis.call('HGET', metadata_key, 'capacity'))
local filter_size = tonumber(redis.call('HGET', metadata_key, 'filter_size'))
local hash_size = tonumber(redis.call('HGET', metadata_key, 'hash_size'))

-- Missing metadata means that the filter has been reset.
if capacity == nil or filter_size == nil then
  return false
end

-- The filter already holds as many elements as it is allowed to.
if capacity >= filter_size then
  return false
end

-- SETBIT replies with the bit value that was stored before the write, so a
-- single call both sets the bit and reports whether it was already set; no
-- separate GETBIT is needed.
local bits_already_set = 0
for i = 1, #ARGV do
  local args = json.parse(ARGV[i])
  if redis.call('SETBIT', args['key'], args['offset'], 1) == 1 then
    bits_already_set = bits_already_set + 1
  end
end

-- Only if the element did not exist yet (at least one of its bits was
-- previously clear), increase the capacity (i.e. count it as added).
if bits_already_set ~= hash_size then
  capacity = capacity + 1
  redis.call('HSET', metadata_key, 'capacity', capacity)
end

return capacity
184+ """
185+
14186
15187class MockRedisProxy (object ):
def __init__(self, *args, **kwargs):
    """Start a patch of the backend module's ``redis.StrictRedis`` and keep the result.

    Positional and keyword arguments are accepted but not used. The object
    returned by ``start()`` is stored as ``self._connection``.
    """
    patch_target = 'pybloom.src.backends.redisbackend.redis.StrictRedis'
    patcher = mock.patch(patch_target, new_callable=FakeStrictRedis)
    self._connection = patcher.start()
19191
def as_pipeline(self):
    """Return a new pipeline obtained from the stored connection."""
    pipeline = self._connection.pipeline()
    return pipeline
@@ -35,8 +207,8 @@ class testRedisProxy(unittest.TestCase):
def setUp(self):
    """Build a RedisProxy whose connection is a FakeStrictRedis-backed patch.

    The patch of the backend module's ``redis.StrictRedis`` is stopped again
    after each test via ``addCleanup`` so it cannot leak into other tests.
    """
    self._proxy = RedisProxy('')
    patcher = mock.patch('pybloom.src.backends.redisbackend.redis.StrictRedis',
                         new_callable=FakeStrictRedis)
    self._proxy._connection = patcher.start()
    # Registered right after start() so the patch is undone even when the
    # rest of setUp or the test itself fails.
    self.addCleanup(patcher.stop)
    # Expose the pipeline's reset on the connection object.
    # NOTE(review): this override was introduced for mockredis; confirm it is
    # still required now that the connection comes from fakeredis.
    self._proxy._connection.reset = self._proxy._connection.pipeline().reset
40212
41213 @mock .patch ('pybloom.src.backends.redisbackend.redis.StrictRedis' , spec = StrictRedis )
42214 def testRetryConnectionError (self , rediss ):
@@ -60,7 +232,7 @@ def testRetryTimeoutError(self, rediss):
60232
def testNoRetry(self):
    """Set a key, then check the replies of ping and get on the proxy."""
    proxy = self._proxy
    proxy.set('key', 'hello')

    ping_reply = proxy.ping()
    assert_that(ping_reply, is_(True))

    stored_value = proxy.get('key')
    assert_that(stored_value, equal_to(b'hello'))
65237
66238 def testPipeline (self ):
@@ -74,8 +246,9 @@ def testPipeline(self):
74246class testRedisBackend (unittest .TestCase ):
def setUp(self):
    """Create a RedisBackend on top of MockRedisProxy and install the test add script."""
    proxy_patch = mock.patch('pybloom.src.backends.redisbackend.RedisProxy', new=MockRedisProxy)
    with proxy_patch:
        self._backend = backend = RedisBackend(array_size=10, hash_size=3,
                                               redis_connection='', filter_size=5)
        # Replace the backend's add script with the one defined in this module.
        backend._lua_add = backend._redis.register_script(LUA_ADD_SCRIPT)
79252
80253 def testRightOffset (self ):
81254 # First, a simple offset (first 2^32)
@@ -88,16 +261,41 @@ def testRightOffset(self):
88261 assert_that (k , equal_to (2 )) # 2 offsets in total (1: [0, 2^32-1], 2: [2^32, 2^33 - 1])
89262 assert_that (offset , equal_to (2 ** 33 - 1 ))
90263
91- def testResetWithNoData (self ):
92- self ._backend .reset ()
93- assert_that (list (self ._backend ._redis .scan_iter ('bloomfilter:*' )), is_ (empty ()))
@mock.patch('pybloom.src.backends.redisbackend.lock', spec=LuaLock)
def testMetadataError(self, mock_lock):
    """Constructing the backend raises BloomFilterException when the lock raises LockError."""
    mock_lock.side_effect = LockError
    expected_message = ("Cannot retrieve metadata from redis. Seems another process has "
                        "acquired the lock and did not released. Check if "
                        "'bloom_filter_lock' key is in your redis server.")

    proxy_patch = mock.patch('pybloom.src.backends.redisbackend.RedisProxy', new=MockRedisProxy)
    with proxy_patch, self.assertRaises(BloomFilterException) as cm:
        self._backend = RedisBackend(array_size=10, hash_size=3, redis_connection='', filter_size=5)

    assert_that(str(cm.exception), equal_to(expected_message))
274+
275+ # @mock.patch('pybloom.src.backends.redisbackend.lock', spec=LuaLock)
def testMetadataOk(self):
    """The metadata hash starts with capacity '0' and reads '1' after one add."""

    def read_metadata():
        # Read the metadata hash and decode its byte keys and values to str.
        raw = self._backend._redis.hgetall(self._backend._metadata_key)
        return {key.decode(): val.decode() for key, val in raw.items()}

    # Metadata as it is right after setUp: nothing has been added yet.
    assert_that(read_metadata(), equal_to({'array_size': '10', 'hash_size': '3',
                                           'filter_size': '5', 'capacity': '0'}))

    # Add data
    self._backend.add(4)

    # Only the capacity field is expected to change.
    assert_that(read_metadata(), equal_to({'array_size': '10', 'hash_size': '3',
                                           'filter_size': '5', 'capacity': '1'}))
289+
def testResetWithData(self):
    """After reset there are no 'bloom_filter:*' keys, the length is 0 and the metadata hash is empty."""
    key_pattern = 'bloom_filter:*'

    def matching_keys():
        # Look the client up on every call, exactly as the assertions need it.
        return list(self._backend._redis.scan_iter(key_pattern))

    self._backend.add(45)
    assert_that(matching_keys(), is_not(empty()))

    self._backend.reset()
    assert_that(matching_keys(), is_(empty()))
    assert_that(len(self._backend), is_(0))

    metadata_items = self._backend._redis.hgetall(self._backend._metadata_key).items()
    assert_that(metadata_items, is_(empty()))
101299
102300 def testAddandCheck (self ):
103301 self ._backend .add ('house' )
@@ -106,9 +304,10 @@ def testAddandCheck(self):
106304 self ._backend += 'horse'
107305 assert_that ('horse' in self ._backend , is_ (True ))
108306
307+
109308class testNumpyBackend (unittest .TestCase ):
def setUp(self):
    """Create the NumpyBackend instance used by each test."""
    backend_options = dict(array_size=10, hash_size=3, filter_size=5)
    self._backend = NumpyBackend(**backend_options)
112311
113312 def testResetWithNoData (self ):
114313 self ._backend .reset ()
@@ -128,12 +327,12 @@ def testAddandCheck(self):
128327 self ._backend += 'horse'
129328 assert_that ('horse' in self ._backend , is_ (True ))
130329
131- assert_that (self ._backend . _capacity , equal_to (2 ))
330+ assert_that (len ( self ._backend ) , equal_to (2 ))
132331
133332
134333class testBitArrayBackend (unittest .TestCase ):
def setUp(self):
    """Create the BitArrayBackend instance used by each test."""
    backend_options = dict(array_size=10, hash_size=3, filter_size=5)
    self._backend = BitArrayBackend(**backend_options)
137336
138337 def testResetWithNoData (self ):
139338 self ._backend .reset ()
@@ -153,7 +352,7 @@ def testAddandCheck(self):
153352 self ._backend += 'horse'
154353 assert_that ('horse' in self ._backend , is_ (True ))
155354
156- assert_that (self ._backend . _capacity , equal_to (2 ))
355+ assert_that (len ( self ._backend ) , equal_to (2 ))
157356
158357
159358class testBloomFilter (unittest .TestCase ):
0 commit comments