Skip to content

Commit aae2c09

Browse files
committed
refactoring classes and rewriting redis backend
1 parent a17f164 commit aae2c09

File tree

1 file changed

+217
-18
lines changed

1 file changed

+217
-18
lines changed

pybloom/tests/__init__.py

Lines changed: 217 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,193 @@
11
import unittest
22

33
import redis
4-
from hamcrest import assert_that, equal_to, raises, is_, instance_of, empty, is_not, greater_than
4+
from fakeredis import FakeStrictRedis
5+
from hamcrest import assert_that, equal_to, raises, is_, instance_of, greater_than, empty, is_not
56
from mock import mock
6-
from mockredis import mock_strict_redis_client
77
from redis import StrictRedis
8+
from redis.exceptions import LockError
9+
from redis.lock import LuaLock
810

911
from pybloom.src.backends.bitarraybackend import BitArrayBackend
1012
from pybloom.src.backends.numpybackend import NumpyBackend
1113
from pybloom.src.backends.redisbackend import RedisBackend, RedisProxy
1214
from pybloom.src.bloomfilter import BloomFilter, BloomFilterException, Options, Size, size_to_human_format
1315

16+
LUA_ADD_SCRIPT = """
17+
18+
-- https://gist.github.com/tylerneylon/59f4bcf316be525b30ab
19+
local json = {}
20+
21+
-- Classify a value for serialisation purposes.
-- Returns type(obj) for anything that is not a table. For tables it returns
-- 'array' when every stored pair sits at consecutive integer keys starting
-- at 1, and 'table' otherwise (an empty table is reported as 'table').
local function kind_of(obj)
  local obj_type = type(obj)
  if obj_type ~= 'table' then
    return obj_type
  end

  -- One probe per stored pair: if the next integer slot is empty, at least
  -- one pair must live under a non-sequence key.
  local seen = 0
  for _ in pairs(obj) do
    if obj[seen + 1] == nil then
      return 'table'
    end
    seen = seen + 1
  end

  if seen == 0 then
    return 'table'
  end
  return 'array'
end
29+
30+
-- Escape the characters of s that need a backslash sequence inside a JSON
-- string and return the escaped copy.
--
-- NOTE: this script is embedded in a non-raw Python string, so Python
-- processes every backslash before Lua sees the source. The newline and
-- carriage-return entries are therefore written with a doubled backslash so
-- that Lua receives a regular escape sequence. The previous spelling handed
-- Lua a backslash followed by a raw CR byte, which the Lua lexer reads as a
-- newline, so carriage returns were never escaped.
local function escape_str(s)
  local in_char = { '\\\\', '\\"', '/', '\b', '\f', '\\n', '\\r', '\t' }
  local out_char = { '\\\\', '\\"', '/', 'b', 'f', 'n', 'r', 't' }

  -- The backslash entry must stay first, otherwise the backslashes added by
  -- the later replacements would be escaped a second time.
  for i, c in ipairs(in_char) do
    s = s:gsub(c, '\\\\' .. out_char[i])
  end
  return s
end
39+
40+
41+
-- Skip optional whitespace starting at pos and then expect the single
-- character delim.
-- Returns the position after the delimiter and true when it is present;
-- otherwise returns the position of the first non-whitespace character and
-- false, or raises an error when err_if_missing is truthy.
local function skip_delim(str, pos, delim, err_if_missing)
  local leading_ws = str:match('^%s*', pos)
  local at = pos + #leading_ws

  if str:sub(at, at) == delim then
    return at + 1, true
  end

  if err_if_missing then
    error('Expected ' .. delim .. ' near position ' .. at)
  end
  return at, false
end
51+
52+
-- Parse the body of a JSON string. pos points at the first character after
-- the opening quote; val is an optional prefix for the result.
-- Returns the decoded string and the position just after the closing quote.
-- Raises an error when the input ends before the closing quote is found.
--
-- NOTE: this script is embedded in a non-raw Python string, so Python
-- processes every backslash before Lua sees the source. The n and r entries
-- of esc_map use a doubled backslash so that Lua receives a regular escape
-- sequence; the previous spelling of the r entry reached Lua as a backslash
-- followed by a raw CR byte, which the Lua lexer reads as a newline.
local function parse_str_val(str, pos, val)
  local early_end_error = 'End of input found while parsing string.'
  local esc_map = {b = '\b', f = '\f', n = '\\n', r = '\\r', t = '\t'}

  -- Collect the pieces in a buffer and join them once at the end instead of
  -- concatenating one character at a time.
  local buf = {val or ''}
  local len = #str

  while pos <= len do
    local c = str:sub(pos, pos)
    if c == '\\"' then
      return table.concat(buf), pos + 1
    end

    if c == '\\\\' then
      -- Escape sequence: the next character decides what gets appended.
      -- string.sub returns an empty string (never nil) past the end.
      local nextc = str:sub(pos + 1, pos + 1)
      if nextc == '' then error(early_end_error) end
      buf[#buf + 1] = esc_map[nextc] or nextc
      pos = pos + 2
    else
      buf[#buf + 1] = c
      pos = pos + 1
    end
  end

  -- Ran off the end without seeing the closing quote.
  error(early_end_error)
end
65+
66+
-- Parse a JSON number that starts at pos.
-- Returns the numeric value and the position just after the matched text.
-- Raises an error when the text at pos cannot be converted to a number.
local function parse_num_val(str, pos)
  local digits = str:match('^-?%d+%.?%d*[eE]?[+-]?%d*', pos)
  local number = tonumber(digits)
  if number == nil then
    error('Error parsing number at position ' .. pos .. '.')
  end
  return number, pos + #digits
end
72+
73+
-- Serialise obj to a JSON string.
-- as_key is truthy when obj is being written as an object key: numbers are
-- then wrapped in quotes, and arrays or tables raise an error.
-- Raises an error for values whose kind cannot be serialised.
function json.stringify(obj, as_key)
  local kind = kind_of(obj) -- 'array' for sequence-like tables, type(obj) otherwise.

  -- Scalars first: each of them maps directly to a literal.
  if kind == 'nil' then
    return 'null'
  end
  if kind == 'boolean' then
    return tostring(obj)
  end
  if kind == 'string' then
    return '"' .. escape_str(obj) .. '"'
  end
  if kind == 'number' then
    if as_key then return '"' .. tostring(obj) .. '"' end
    return tostring(obj)
  end

  -- Containers are assembled as a list of fragments and joined at the end.
  local parts = {}
  if kind == 'array' then
    if as_key then error('Can\\\'t encode array as key.') end
    parts[#parts + 1] = '['
    for i, val in ipairs(obj) do
      if i > 1 then parts[#parts + 1] = ', ' end
      parts[#parts + 1] = json.stringify(val)
    end
    parts[#parts + 1] = ']'
  elseif kind == 'table' then
    if as_key then error('Can\\\'t encode table as key.') end
    parts[#parts + 1] = '{'
    local is_first = true
    for k, v in pairs(obj) do
      if not is_first then parts[#parts + 1] = ', ' end
      is_first = false
      parts[#parts + 1] = json.stringify(k, true)
      parts[#parts + 1] = ':'
      parts[#parts + 1] = json.stringify(v)
    end
    parts[#parts + 1] = '}'
  else
    error('Unjsonifiable type: ' .. kind .. '.')
  end

  return table.concat(parts)
end
108+
109+
json.null = {}
110+
111+
-- Parse one JSON value from str, starting at pos (default 1).
-- Returns the parsed value and the position just after it. When end_delim is
-- given and is the first non-whitespace character found, returns nil and the
-- position after it; the object and array loops use this to detect their
-- closing bracket. JSON null is returned as the json.null sentinel.
-- Raises an error on malformed input.
function json.parse(str, pos, end_delim)
  pos = pos or 1
  if pos > #str then error('Reached unexpected end of input.') end
  pos = pos + #str:match('^%s*', pos) -- Skip whitespace.
  local first = str:sub(pos, pos)

  -- Object: read key/value pairs until the closing brace is reported.
  if first == '{' then
    local obj = {}
    local delim_found = true
    pos = pos + 1
    while true do
      local key
      key, pos = json.parse(str, pos, '}')
      if key == nil then return obj, pos end
      if not delim_found then error('Comma missing between object items.') end
      pos = skip_delim(str, pos, ':', true) -- true -> error if missing.
      local value
      value, pos = json.parse(str, pos)
      obj[key] = value
      pos, delim_found = skip_delim(str, pos, ',')
    end
  end

  -- Array: read values until the closing bracket is reported.
  if first == '[' then
    local arr = {}
    local delim_found = true
    pos = pos + 1
    while true do
      local val
      val, pos = json.parse(str, pos, ']')
      if val == nil then return arr, pos end
      if not delim_found then error('Comma missing between array items.') end
      arr[#arr + 1] = val
      pos, delim_found = skip_delim(str, pos, ',')
    end
  end

  -- String.
  if first == '"' then
    return parse_str_val(str, pos + 1)
  end

  -- Number.
  if first == '-' or first:match('%d') then
    return parse_num_val(str, pos)
  end

  -- End of the enclosing object or array.
  if first == end_delim then
    return nil, pos + 1
  end

  -- Remaining possibilities: true, false, or null.
  local literals = { {'true', true}, {'false', false}, {'null', json.null} }
  for _, entry in ipairs(literals) do
    local lit_str, lit_val = entry[1], entry[2]
    local lit_end = pos + #lit_str - 1
    if str:sub(pos, lit_end) == lit_str then return lit_val, lit_end + 1 end
  end

  local pos_info_str = 'position ' .. pos .. ': ' .. str:sub(pos, pos + 10)
  error('Invalid json syntax starting at ' .. pos_info_str)
end
153+
154+
-- Add one element to the bloom filter.
-- KEYS[1] is the metadata hash; every ARGV entry is a JSON object holding
-- the 'key' and 'offset' of one bit to set.
-- Returns false when the metadata is missing or capacity has reached
-- filter_size, otherwise the (possibly incremented) capacity.
local capacity = tonumber(redis.call('HGET', KEYS[1], 'capacity'))
local filter_size = tonumber(redis.call('HGET', KEYS[1], 'filter_size'))
local hash_size = tonumber(redis.call('HGET', KEYS[1], 'hash_size'))

-- Missing metadata means that the filter has been reset. hash_size is
-- checked as well: with a nil hash_size the comparison further down would
-- always succeed and every call would bump the capacity.
if capacity == nil or filter_size == nil or hash_size == nil then
  return false
end

if capacity >= filter_size then
  return false
end

-- SETBIT returns the previous value of the bit, so one call both sets the
-- bit and tells us whether it was already set.
local sum = 0
for i = 1, #ARGV do
  local args = json.parse(ARGV[i])
  if redis.call('SETBIT', args['key'], args['offset'], 1) == 1 then
    sum = sum + 1
  end
end

-- Only if the element does not exist yet, increase the capacity (i.e. add it).
-- The element is treated as present when all hash_size bits were already set.
if sum ~= hash_size then
  capacity = capacity + 1
  redis.call('HSET', KEYS[1], 'capacity', capacity)
end

return capacity
184+
"""
185+
14186

15187
class MockRedisProxy(object):
16188
def __init__(self, *args, **kwargs):
17189
self._connection = mock.patch('pybloom.src.backends.redisbackend.redis.StrictRedis',
18-
new_callable=mock_strict_redis_client).start()
190+
new_callable=FakeStrictRedis).start()
19191

20192
def as_pipeline(self):
21193
return self._connection.pipeline()
@@ -35,8 +207,8 @@ class testRedisProxy(unittest.TestCase):
35207
def setUp(self):
36208
self._proxy = RedisProxy('')
37209
self._proxy._connection = mock.patch('pybloom.src.backends.redisbackend.redis.StrictRedis',
38-
new_callable=mock_strict_redis_client).start()
39-
self._proxy._connection.reset = self._proxy._connection.pipeline()._reset # bad signature in mockredis
210+
new_callable=FakeStrictRedis).start()
211+
self._proxy._connection.reset = self._proxy._connection.pipeline().reset # bad signature in mockredis
40212

41213
@mock.patch('pybloom.src.backends.redisbackend.redis.StrictRedis', spec=StrictRedis)
42214
def testRetryConnectionError(self, rediss):
@@ -60,7 +232,7 @@ def testRetryTimeoutError(self, rediss):
60232

61233
def testNoRetry(self):
62234
self._proxy.set('key', 'hello')
63-
assert_that(self._proxy.ping(), is_(b'PONG'))
235+
assert_that(self._proxy.ping(), is_(True))
64236
assert_that(self._proxy.get('key'), equal_to(b'hello'))
65237

66238
def testPipeline(self):
@@ -74,8 +246,9 @@ def testPipeline(self):
74246
class testRedisBackend(unittest.TestCase):
75247
def setUp(self):
76248
with mock.patch('pybloom.src.backends.redisbackend.RedisProxy', new=MockRedisProxy):
77-
self._backend = RedisBackend(array_size=10, hash_size=3, redis_connection='')
78-
# # self._backend._redis = MockRedisProxy()
249+
self._backend = RedisBackend(array_size=10, hash_size=3, redis_connection='', filter_size=5)
250+
self._backend._lua_add = self._backend._redis.register_script(LUA_ADD_SCRIPT)
251+
# self._backend._lua_add = self._backend._redis.register_script(LUA_ADD_KEY)
79252

80253
def testRightOffset(self):
81254
# First, a simple offset (first 2^32)
@@ -88,16 +261,41 @@ def testRightOffset(self):
88261
assert_that(k, equal_to(2)) # 2 offsets in total (1: [0, 2^32-1], 2: [2^32, 2^33 - 1])
89262
assert_that(offset, equal_to(2 ** 33 - 1))
90263

91-
def testResetWithNoData(self):
92-
self._backend.reset()
93-
assert_that(list(self._backend._redis.scan_iter('bloomfilter:*')), is_(empty()))
264+
@mock.patch('pybloom.src.backends.redisbackend.lock', spec=LuaLock)
265+
def testMetadataError(self, mock_lock):
266+
mock_lock.side_effect = LockError
267+
with mock.patch('pybloom.src.backends.redisbackend.RedisProxy', new=MockRedisProxy):
268+
with self.assertRaises(BloomFilterException) as cm:
269+
self._backend = RedisBackend(array_size=10, hash_size=3, redis_connection='', filter_size=5)
94270

95-
def testResetWithData(self):
271+
assert_that(str(cm.exception), equal_to("Cannot retrieve metadata from redis. Seems another process has "
272+
"acquired the lock and did not released. Check if "
273+
"'bloom_filter_lock' key is in your redis server."))
274+
275+
# @mock.patch('pybloom.src.backends.redisbackend.lock', spec=LuaLock)
276+
def testMetadataOk(self):
277+
# Check the initial metadata: nothing has been added yet, so capacity is 0
278+
response = {key.decode(): val.decode() for key, val in
279+
self._backend._redis.hgetall(self._backend._metadata_key).items()}
280+
assert_that(response, equal_to(dict(array_size='10', hash_size='3', filter_size='5', capacity='0')))
281+
282+
# Add data
96283
self._backend.add(4)
97-
assert_that(list(self._backend._redis.scan_iter('bloomfilter:*')), is_not(empty()))
284+
285+
# Check metadata again
286+
response = {key.decode(): val.decode() for key, val in
287+
self._backend._redis.hgetall(self._backend._metadata_key).items()}
288+
assert_that(response, equal_to(dict(array_size='10', hash_size='3', filter_size='5', capacity='1')))
289+
290+
def testResetWithData(self):
291+
self._backend.add(45)
292+
assert_that(list(self._backend._redis.scan_iter('bloom_filter:*')), is_not(empty()))
98293

99294
self._backend.reset()
100-
assert_that(list(self._backend._redis.scan_iter('bloomfilter:*')), is_(empty()))
295+
assert_that(list(self._backend._redis.scan_iter('bloom_filter:*')), is_(empty()))
296+
assert_that(len(self._backend), is_(0))
297+
response = self._backend._redis.hgetall(self._backend._metadata_key).items()
298+
assert_that(response, is_(empty()))
101299

102300
def testAddandCheck(self):
103301
self._backend.add('house')
@@ -106,9 +304,10 @@ def testAddandCheck(self):
106304
self._backend += 'horse'
107305
assert_that('horse' in self._backend, is_(True))
108306

307+
109308
class testNumpyBackend(unittest.TestCase):
110309
def setUp(self):
111-
self._backend = NumpyBackend(array_size=10, hash_size=3)
310+
self._backend = NumpyBackend(array_size=10, hash_size=3, filter_size=5)
112311

113312
def testResetWithNoData(self):
114313
self._backend.reset()
@@ -128,12 +327,12 @@ def testAddandCheck(self):
128327
self._backend += 'horse'
129328
assert_that('horse' in self._backend, is_(True))
130329

131-
assert_that(self._backend._capacity, equal_to(2))
330+
assert_that(len(self._backend), equal_to(2))
132331

133332

134333
class testBitArrayBackend(unittest.TestCase):
135334
def setUp(self):
136-
self._backend = BitArrayBackend(array_size=10, hash_size=3)
335+
self._backend = BitArrayBackend(array_size=10, hash_size=3, filter_size=5)
137336

138337
def testResetWithNoData(self):
139338
self._backend.reset()
@@ -153,7 +352,7 @@ def testAddandCheck(self):
153352
self._backend += 'horse'
154353
assert_that('horse' in self._backend, is_(True))
155354

156-
assert_that(self._backend._capacity, equal_to(2))
355+
assert_that(len(self._backend), equal_to(2))
157356

158357

159358
class testBloomFilter(unittest.TestCase):

0 commit comments

Comments
 (0)