Skip to content

Commit 8e33c0e

Browse files
refactor: removing usage of GetDocument proto in favor of BatchGetDocuments (#316)
* Began refactoring away from GetDocument proto which contains no `read_time` field and in general is a shortcut around `BatchGetDocuments` * Correctly instantiate snapshots for missing documents * Removed stale NotFound exception * Removed unnecessary empty list check * Linting fix * Expanded batch_get change to async classes * Updated variable name * Added get_batch to async classes * Improved consumption of async generators * Fixed test coverage * Fixed broken mock in test * Linting * Reverted the move of AsyncIter Co-authored-by: Craig Labenz <craiglabenz@google.com>
1 parent 18a2a8a commit 8e33c0e

File tree

11 files changed

+175
-102
lines changed

11 files changed

+175
-102
lines changed

packages/google-cloud-firestore/google/cloud/firestore_v1/async_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from google.cloud.firestore_v1.services.firestore.transports import (
5050
grpc_asyncio as firestore_grpc_transport,
5151
)
52-
from typing import Any, AsyncGenerator, Iterable
52+
from typing import Any, AsyncGenerator, Iterable, List
5353

5454

5555
class AsyncClient(BaseClient):
@@ -209,7 +209,7 @@ def document(self, *document_path: str) -> AsyncDocumentReference:
209209

210210
async def get_all(
211211
self,
212-
references: list,
212+
references: List[AsyncDocumentReference],
213213
field_paths: Iterable[str] = None,
214214
transaction=None,
215215
retry: retries.Retry = gapic_v1.method.DEFAULT,

packages/google-cloud-firestore/google/cloud/firestore_v1/async_document.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,27 @@
1313
# limitations under the License.
1414

1515
"""Classes for representing documents for the Google Cloud Firestore API."""
16+
import datetime
17+
import logging
1618

1719
from google.api_core import gapic_v1 # type: ignore
1820
from google.api_core import retry as retries # type: ignore
21+
from google.cloud._helpers import _datetime_to_pb_timestamp # type: ignore
1922

2023
from google.cloud.firestore_v1.base_document import (
2124
BaseDocumentReference,
2225
DocumentSnapshot,
2326
_first_write_result,
2427
)
25-
26-
from google.api_core import exceptions # type: ignore
2728
from google.cloud.firestore_v1 import _helpers
2829
from google.cloud.firestore_v1.types import write
29-
from google.protobuf import timestamp_pb2
30+
from google.protobuf.timestamp_pb2 import Timestamp
3031
from typing import Any, AsyncGenerator, Coroutine, Iterable, Union
3132

3233

34+
logger = logging.getLogger(__name__)
35+
36+
3337
class AsyncDocumentReference(BaseDocumentReference):
3438
"""A reference to a document in a Firestore database.
3539
@@ -289,7 +293,7 @@ async def delete(
289293
option: _helpers.WriteOption = None,
290294
retry: retries.Retry = gapic_v1.method.DEFAULT,
291295
timeout: float = None,
292-
) -> timestamp_pb2.Timestamp:
296+
) -> Timestamp:
293297
"""Delete the current document in the Firestore database.
294298
295299
Args:
@@ -353,31 +357,34 @@ async def get(
353357
:attr:`create_time` attributes will all be ``None`` and
354358
its :attr:`exists` attribute will be ``False``.
355359
"""
356-
request, kwargs = self._prep_get(field_paths, transaction, retry, timeout)
360+
from google.cloud.firestore_v1.base_client import _parse_batch_get
357361

358-
firestore_api = self._client._firestore_api
359-
try:
360-
document_pb = await firestore_api.get_document(
361-
request=request, metadata=self._client._rpc_metadata, **kwargs,
362+
request, kwargs = self._prep_batch_get(field_paths, transaction, retry, timeout)
363+
364+
response_iter = await self._client._firestore_api.batch_get_documents(
365+
request=request, metadata=self._client._rpc_metadata, **kwargs,
366+
)
367+
368+
async for resp in response_iter:
369+
# Immediate return as the iterator should only ever have one item.
370+
return _parse_batch_get(
371+
get_doc_response=resp,
372+
reference_map={self._document_path: self},
373+
client=self._client,
362374
)
363-
except exceptions.NotFound:
364-
data = None
365-
exists = False
366-
create_time = None
367-
update_time = None
368-
else:
369-
data = _helpers.decode_dict(document_pb.fields, self._client)
370-
exists = True
371-
create_time = document_pb.create_time
372-
update_time = document_pb.update_time
375+
376+
logger.warning(
377+
"`batch_get_documents` unexpectedly returned empty "
378+
"stream. Expected one object.",
379+
)
373380

374381
return DocumentSnapshot(
375-
reference=self,
376-
data=data,
377-
exists=exists,
378-
read_time=None, # No server read_time available
379-
create_time=create_time,
380-
update_time=update_time,
382+
self,
383+
None,
384+
exists=False,
385+
read_time=_datetime_to_pb_timestamp(datetime.datetime.now()),
386+
create_time=None,
387+
update_time=None,
381388
)
382389

383390
async def collections(

packages/google-cloud-firestore/google/cloud/firestore_v1/base_document.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def delete(
268268
) -> NoReturn:
269269
raise NotImplementedError
270270

271-
def _prep_get(
271+
def _prep_batch_get(
272272
self,
273273
field_paths: Iterable[str] = None,
274274
transaction=None,
@@ -285,7 +285,8 @@ def _prep_get(
285285
mask = None
286286

287287
request = {
288-
"name": self._document_path,
288+
"database": self._client._database_string,
289+
"documents": [self._document_path],
289290
"mask": mask,
290291
"transaction": _helpers.get_transaction_id(transaction),
291292
}

packages/google-cloud-firestore/google/cloud/firestore_v1/document.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,28 @@
1313
# limitations under the License.
1414

1515
"""Classes for representing documents for the Google Cloud Firestore API."""
16+
import datetime
17+
import logging
1618

1719
from google.api_core import gapic_v1 # type: ignore
1820
from google.api_core import retry as retries # type: ignore
21+
from google.cloud._helpers import _datetime_to_pb_timestamp # type: ignore
1922

2023
from google.cloud.firestore_v1.base_document import (
2124
BaseDocumentReference,
2225
DocumentSnapshot,
2326
_first_write_result,
2427
)
25-
26-
from google.api_core import exceptions # type: ignore
2728
from google.cloud.firestore_v1 import _helpers
2829
from google.cloud.firestore_v1.types import write
2930
from google.cloud.firestore_v1.watch import Watch
30-
from google.protobuf import timestamp_pb2
31+
from google.protobuf.timestamp_pb2 import Timestamp
3132
from typing import Any, Callable, Generator, Iterable
3233

3334

35+
logger = logging.getLogger(__name__)
36+
37+
3438
class DocumentReference(BaseDocumentReference):
3539
"""A reference to a document in a Firestore database.
3640
@@ -325,7 +329,7 @@ def delete(
325329
option: _helpers.WriteOption = None,
326330
retry: retries.Retry = gapic_v1.method.DEFAULT,
327331
timeout: float = None,
328-
) -> timestamp_pb2.Timestamp:
332+
) -> Timestamp:
329333
"""Delete the current document in the Firestore database.
330334
331335
Args:
@@ -389,31 +393,35 @@ def get(
389393
:attr:`create_time` attributes will all be ``None`` and
390394
its :attr:`exists` attribute will be ``False``.
391395
"""
392-
request, kwargs = self._prep_get(field_paths, transaction, retry, timeout)
396+
from google.cloud.firestore_v1.base_client import _parse_batch_get
397+
398+
request, kwargs = self._prep_batch_get(field_paths, transaction, retry, timeout)
399+
400+
response_iter = self._client._firestore_api.batch_get_documents(
401+
request=request, metadata=self._client._rpc_metadata, **kwargs,
402+
)
403+
404+
get_doc_response = next(response_iter, None)
393405

394-
firestore_api = self._client._firestore_api
395-
try:
396-
document_pb = firestore_api.get_document(
397-
request=request, metadata=self._client._rpc_metadata, **kwargs,
406+
if get_doc_response is not None:
407+
return _parse_batch_get(
408+
get_doc_response=get_doc_response,
409+
reference_map={self._document_path: self},
410+
client=self._client,
398411
)
399-
except exceptions.NotFound:
400-
data = None
401-
exists = False
402-
create_time = None
403-
update_time = None
404-
else:
405-
data = _helpers.decode_dict(document_pb.fields, self._client)
406-
exists = True
407-
create_time = document_pb.create_time
408-
update_time = document_pb.update_time
412+
413+
logger.warning(
414+
"`batch_get_documents` unexpectedly returned empty "
415+
"stream. Expected one object.",
416+
)
409417

410418
return DocumentSnapshot(
411-
reference=self,
412-
data=data,
413-
exists=exists,
414-
read_time=None, # No server read_time available
415-
create_time=create_time,
416-
update_time=update_time,
419+
self,
420+
None,
421+
exists=False,
422+
read_time=_datetime_to_pb_timestamp(datetime.datetime.now()),
423+
create_time=None,
424+
update_time=None,
417425
)
418426

419427
def collections(

packages/google-cloud-firestore/tests/unit/v1/test__helpers.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import aiounittest
1617
import datetime
1718
import sys
1819
import unittest
1920

2021
import mock
22+
import pytest
23+
from typing import List
2124

2225

2326
class AsyncMock(mock.MagicMock):
@@ -26,10 +29,14 @@ async def __call__(self, *args, **kwargs):
2629

2730

2831
class AsyncIter:
32+
"""Utility to help recreate the effect of an async generator. Useful when
33+
you need to mock a system that requires `async for`.
34+
"""
35+
2936
def __init__(self, items):
3037
self.items = items
3138

32-
async def __aiter__(self, **_):
39+
async def __aiter__(self):
3340
for i in self.items:
3441
yield i
3542

@@ -2424,6 +2431,15 @@ def test_retry_and_timeout(self):
24242431
self.assertEqual(kwargs, expected)
24252432

24262433

2434+
class TestAsyncGenerator(aiounittest.AsyncTestCase):
2435+
@pytest.mark.asyncio
2436+
async def test_async_iter(self):
2437+
consumed: List[int] = []
2438+
async for el in AsyncIter([1, 2, 3]):
2439+
consumed.append(el)
2440+
self.assertEqual(consumed, [1, 2, 3])
2441+
2442+
24272443
def _value_pb(**kwargs):
24282444
from google.cloud.firestore_v1.types.document import Value
24292445

packages/google-cloud-firestore/tests/unit/v1/test_async_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import aiounittest
1919

2020
import mock
21-
from tests.unit.v1.test__helpers import AsyncMock, AsyncIter
21+
from tests.unit.v1.test__helpers import AsyncIter, AsyncMock
2222

2323

2424
class TestAsyncClient(aiounittest.AsyncTestCase):

packages/google-cloud-firestore/tests/unit/v1/test_async_collection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import aiounittest
1818

1919
import mock
20-
from tests.unit.v1.test__helpers import AsyncMock, AsyncIter
20+
from tests.unit.v1.test__helpers import AsyncIter, AsyncMock
2121

2222

2323
class TestAsyncCollectionReference(aiounittest.AsyncTestCase):

0 commit comments

Comments
 (0)