Skip to content

Commit

Permalink
Implementing new methods for scrolling
Browse files Browse the repository at this point in the history
  • Loading branch information
tcatrain committed Nov 29, 2018
1 parent 7ad1551 commit e68e930
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 3 deletions.
28 changes: 26 additions & 2 deletions elasticmock/fake_elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from elasticsearch.client.utils import query_params
from elasticsearch.exceptions import NotFoundError

from elasticmock.utilities import get_random_id
from elasticmock.utilities import get_random_id, get_random_scroll_id


PY3 = sys.version_info[0] == 3
Expand All @@ -20,6 +20,7 @@ class FakeElasticsearch(Elasticsearch):

def __init__(self, hosts=None, transport_class=None, **kwargs):
self.__documents_dict = {}
self.__scrolls = {}

@query_params()
def ping(self, params=None):
Expand Down Expand Up @@ -183,10 +184,33 @@ def search(self, index=None, doc_type=None, body=None, params=None):
for match in matches:
match['_score'] = 1.0
hits.append(match)
result['hits']['hits'] = hits

if 'scroll' in params:
result['_scroll_id'] = str(get_random_scroll_id())
params['size'] = int(params.get('size') if 'size' in params else 10)
params['from'] = int(params.get('from') + params.get('size') if 'from' in params else 0)
self.__scrolls[result.get('_scroll_id')] = {
'index' : index,
'doc_type' : doc_type,
'body' : body,
'params' : params
}
hits = hits[params.get('from'):params.get('from') + params.get('size')]

result['hits']['hits'] = hits
return result

@query_params('scroll')
def scroll(self, scroll_id, params=None):
scroll = self.__scrolls.pop(scroll_id)
result = self.search(
index = scroll.get('index'),
doc_type = scroll.get('doc_type'),
body = scroll.get('body'),
params = scroll.get('params')
)
return result

@query_params('consistency', 'parent', 'refresh', 'replication', 'routing',
'timeout', 'version', 'version_type')
def delete(self, index, doc_type, id, params=None):
Expand Down
5 changes: 5 additions & 0 deletions elasticmock/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@

import random
import string
import base64

DEFAULT_ELASTICSEARCH_ID_SIZE = 20
CHARSET_FOR_ELASTICSEARCH_ID = string.ascii_letters + string.digits

DEFAULT_ELASTICSEARCH_SEARCHRESULTPHASE_COUNT = 6

def get_random_id(size=DEFAULT_ELASTICSEARCH_ID_SIZE):
return ''.join(random.choice(CHARSET_FOR_ELASTICSEARCH_ID) for _ in range(size))

def get_random_scroll_id(size=DEFAULT_ELASTICSEARCH_SEARCHRESULTPHASE_COUNT):
return base64.b64encode(''.join(get_random_id() for _ in range(size)).encode())
29 changes: 28 additions & 1 deletion tests/test_elasticmock.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-

import unittest

import elasticsearch
from elasticsearch.exceptions import NotFoundError

Expand Down Expand Up @@ -242,6 +241,34 @@ def test_doc_type_can_be_list(self):
result = self.es.search(doc_type=doc_types[:2])
self.assertEqual(count_per_doc_type * 2, result.get('hits').get('total'))

def test_search_with_scroll_param(self):
for _ in range(100):
self.es.index(index='groups', doc_type='groups', body={'budget': 1000})

result = self.es.search(index='groups', params={'scroll' : '1m', 'size' : 30})
self.assertNotEqual(None, result.get('_scroll_id', None))
self.assertEqual(30, len(result.get('hits').get('hits')))
self.assertEqual(100, result.get('hits').get('total'))

def test_scrolling(self):
for _ in range(100):
self.es.index(index='groups', doc_type='groups', body={'budget': 1000})

result = self.es.search(index='groups', params={'scroll' : '1m', 'size' : 30})
self.assertNotEqual(None, result.get('_scroll_id', None))
self.assertEqual(30, len(result.get('hits').get('hits')))
self.assertEqual(100, result.get('hits').get('total'))

for _ in range(2):
result = self.es.scroll(scroll_id = result.get('_scroll_id'), scroll = '1m')
self.assertNotEqual(None, result.get('_scroll_id', None))
self.assertEqual(30, len(result.get('hits').get('hits')))
self.assertEqual(100, result.get('hits').get('total'))

result = self.es.scroll(scroll_id = result.get('_scroll_id'), scroll = '1m')
self.assertNotEqual(None, result.get('_scroll_id', None))
self.assertEqual(10, len(result.get('hits').get('hits')))
self.assertEqual(100, result.get('hits').get('total'))

if __name__ == '__main__':
unittest.main()

0 comments on commit e68e930

Please sign in to comment.