Skip to content

Commit

Permalink
Refactor BM25 and Search Application search.
Browse files Browse the repository at this point in the history
  • Loading branch information
afoucret committed Jul 19, 2023
1 parent 37a276b commit e3550e4
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 38 deletions.
13 changes: 6 additions & 7 deletions wikipedia/challenges/default.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
"name": "refresh-after-index",
"operation": "refresh-after-index"
},
{
"name": "query-string-search",
"operation": "query-string-search",
"clients": {{search_clients | default(10)}},
"warmup-iterations": 500
},
{
"name": "create-default-search-application",
"operation": "create-default-search-application"
Expand All @@ -33,13 +39,6 @@
"name": "default-search-application-search",
"operation": "default-search-application-search",
"clients": {{search_clients | default(10)}}
},
{
"name": "query-string-search",
"operation": "query-string-search",
"clients": {{search_clients | default(10)}},
"warmup-iterations": 500,
"iterations": 10000
}
]
}
2 changes: 1 addition & 1 deletion wikipedia/operations/default.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@
"operation-type": "search",
"param-source": "query-string-search",
"size" : {{search_size | default(20)}},
"search_fields" : "title"
"search-fields" : "{{search_fields | default("*")}}"
}
49 changes: 19 additions & 30 deletions wikipedia/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def query_iterator(k: int) -> Iterator[str]:
probabilities = [float(probability) for _, probability in queries_with_probabilities]

for query in random.choices(queries, weights=probabilities, k=k):
yield query
# remove special chars from the query + lowercase
yield re.sub("[^0-9a-zA-Z]+", " ", query).lower()


class SearchApplicationParams:
Expand All @@ -49,10 +50,9 @@ def params(self):
}


class SearchApplicationSearchParamSource(ParamSource):
class QueryIteratorParamSource(ParamSource):
def __init__(self, track, params, **kwargs):
super().__init__(track, params, **kwargs)
self.search_application_params = SearchApplicationParams(track, params)
self._queries_iterator = None

def size(self):
Expand All @@ -64,45 +64,34 @@ def partition(self, partition_index, total_partitions):
self._queries_iterator = query_iterator(partition_size)
return self


class SearchApplicationSearchParamSource(QueryIteratorParamSource):
def __init__(self, track, params, **kwargs):
super().__init__(track, params, **kwargs)
self.search_application_params = SearchApplicationParams(track, params)

def params(self):
# remover special chars from the query + lowercase
query = re.sub("[^0-9a-zA-Z]+", " ", next(self._queries_iterator)).lower()
query = next(self._queries_iterator)
return {
"method": "POST",
"path": f"{SEARCH_APPLICATION_ROOT_ENDPOINT}/{self.search_application_params.name}/_search",
"body": {"params": {"query_string": query}},
"body": {
"params": {
"query_string": query,
},
},
}


class QueryParamSource:
class QueryParamSource(QueryIteratorParamSource):
def __init__(self, track, params, **kwargs):
if len(track.indices) == 1:
default_index = track.indices[0].name
if len(track.indices[0].types) == 1:
default_type = track.indices[0].types[0].name
else:
default_type = None
else:
default_index = "_all"
default_type = None

self._index_name = params.get("index", default_index)
self._type_name = params.get("type", default_type)
super().__init__(track, params, **kwargs)
self._index_name = params.get("index", track.indices[0].name if len(track.indices) == 1 else "_all")
self._cache = params.get("cache", False)
self._params = params
self.infinite = True
self._queries_iterator = None

def partition(self, partition_index, total_partitions):
if self._queries_iterator is None:
partition_size = math.ceil(self._params.get("iterations", 10000) / total_partitions)
self._queries_iterator = query_iterator(partition_size)
return self

def params(self):
query_str = re.sub("[^0-9a-zA-Z]+", " ", next(self._queries_iterator)).lower()
result = {
"body": {"query": {"query_string": {"query": query_str, "default_field": self._params["search_fields"]}}},
"body": {"query": {"query_string": {"query": next(self._queries_iterator), "default_field": self._params["search-fields"]}}},
"size": self._params["size"],
"index": self._index_name,
}
Expand Down

0 comments on commit e3550e4

Please sign in to comment.