Skip to content

Commit

Permalink
reset continuation token (Azure#18772)
Browse files Browse the repository at this point in the history
* reset continuation token

* add tests

* update
  • Loading branch information
xiangyan99 authored May 18, 2021
1 parent ef56f99 commit 20a3129
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,19 +134,23 @@ def _extract_data_cb(self, response): # pylint:disable=no-self-use

@_ensure_response
def get_facets(self):
self.continuation_token = None
facets = self._response.facets
if facets is not None and self._facets is None:
self._facets = {k: [x.as_dict() for x in v] for k, v in facets.items()}
return self._facets

@_ensure_response
def get_coverage(self):
self.continuation_token = None
return self._response.coverage

@_ensure_response
def get_count(self):
self.continuation_token = None
return self._response.count

@_ensure_response
def get_answers(self):
self.continuation_token = None
return self._response.answers
Original file line number Diff line number Diff line change
Expand Up @@ -118,19 +118,23 @@ async def _extract_data_cb(self, response): # pylint:disable=no-self-use

@_ensure_response
async def get_facets(self):
self.continuation_token = None
facets = self._response.facets
if facets is not None and self._facets is None:
self._facets = {k: [x.as_dict() for x in v] for k, v in facets.items()}
return self._facets

@_ensure_response
async def get_coverage(self):
self.continuation_token = None
return self._response.coverage

@_ensure_response
async def get_count(self):
self.continuation_token = None
return self._response.count

@_ensure_response
async def get_answers(self):
self.continuation_token = None
return self._response.answers
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
try:
from unittest import mock
except ImportError:
import mock
from azure.core.credentials import AzureKeyCredential
from azure.search.documents._generated.models import SearchDocumentsResult, SearchResult
from azure.search.documents.aio import SearchClient
from azure.search.documents.aio._search_client_async import AsyncSearchPageIterator

CREDENTIAL = AzureKeyCredential(key="test_api_key")

class TestSearchClientAsync(object):
@mock.patch(
"azure.search.documents._generated.aio.operations._documents_operations.DocumentsOperations.search_post"
)
async def test_get_count_reset_continuation_token(self, mock_search_post):
client = SearchClient("endpoint", "index name", CREDENTIAL)
result = await client.search(search_text="search text")
assert result._page_iterator_class is AsyncSearchPageIterator
search_result = SearchDocumentsResult()
search_result.results = [SearchResult(additional_properties={"key": "val"})]
mock_search_post.return_value = search_result
await result.__anext__()
result._first_page_iterator_instance.continuation_token = "fake token"
await result.get_count()
assert not result._first_page_iterator_instance.continuation_token
16 changes: 16 additions & 0 deletions sdk/search/azure-search-documents/tests/test_search_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,22 @@ def test_suggest_bad_argument(self):
repr("bad_query")
)

@mock.patch(
"azure.search.documents._generated.operations._documents_operations.DocumentsOperations.search_post"
)
def test_get_count_reset_continuation_token(self, mock_search_post):
client = SearchClient("endpoint", "index name", CREDENTIAL)
result = client.search(search_text="search text")
assert isinstance(result, ItemPaged)
assert result._page_iterator_class is SearchPageIterator
search_result = SearchDocumentsResult()
search_result.results = [SearchResult(additional_properties={"key": "val"})]
mock_search_post.return_value = search_result
result.__next__()
result._first_page_iterator_instance.continuation_token = "fake token"
result.get_count()
assert not result._first_page_iterator_instance.continuation_token

@mock.patch(
"azure.search.documents._generated.operations._documents_operations.DocumentsOperations.autocomplete_post"
)
Expand Down

0 comments on commit 20a3129

Please sign in to comment.