Skip to content

Commit

Permalink
Add 'Client.collections' method. (#6650)
Browse files Browse the repository at this point in the history
Lists top-level collections in the client's database.

Closes #6553.
  • Loading branch information
tseaver authored Nov 29, 2018
1 parent fd25494 commit d8a3ec3
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 0 deletions.
24 changes: 24 additions & 0 deletions firestore/google/cloud/firestore_v1beta1/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,19 @@ def get_all(self, references, field_paths=None, transaction=None):
for get_doc_response in response_iterator:
yield _parse_batch_get(get_doc_response, reference_map, self)

def collections(self):
"""List top-level collections of the client's database.
Returns:
Sequence[~.firestore_v1beta1.collection.CollectionReference]:
iterator of subcollections of the current document.
"""
iterator = self._firestore_api.list_collection_ids(
self._database_string, metadata=self._rpc_metadata)
iterator.client = self
iterator.item_to_value = _item_to_collection_ref
return iterator

def batch(self):
"""Get a batch instance from this client.
Expand Down Expand Up @@ -477,3 +490,14 @@ def _get_doc_mask(field_paths):
return None
else:
return types.DocumentMask(field_paths=field_paths)


def _item_to_collection_ref(iterator, item):
"""Convert collection ID to collection ref.
Args:
iterator (google.api_core.page_iterator.GRPCIterator):
iterator response
item (str): ID of the collection
"""
return iterator.client.collection(item)
38 changes: 38 additions & 0 deletions firestore/tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,44 @@ def test_write_bad_arg(self):
extra = '{!r} was provided'.format('spinach')
self.assertEqual(exc_info.exception.args, (_BAD_OPTION_ERR, extra))

def test_collections(self):
from google.api_core.page_iterator import Iterator
from google.api_core.page_iterator import Page
from google.cloud.firestore_v1beta1.collection import (
CollectionReference)

collection_ids = ['users', 'projects']
client = self._make_default_one()
firestore_api = mock.Mock(spec=['list_collection_ids'])
client._firestore_api_internal = firestore_api

class _Iterator(Iterator):

def __init__(self, pages):
super(_Iterator, self).__init__(client=None)
self._pages = pages

def _next_page(self):
if self._pages:
page, self._pages = self._pages[0], self._pages[1:]
return Page(self, page, self.item_to_value)

iterator = _Iterator(pages=[collection_ids])
firestore_api.list_collection_ids.return_value = iterator

collections = list(client.collections())

self.assertEqual(len(collections), len(collection_ids))
for collection, collection_id in zip(collections, collection_ids):
self.assertIsInstance(collection, CollectionReference)
self.assertEqual(collection.parent, None)
self.assertEqual(collection.id, collection_id)

firestore_api.list_collection_ids.assert_called_once_with(
client._database_string,
metadata=client._rpc_metadata,
)

def _get_all_helper(self, client, references, document_pbs, **kwargs):
# Create a minimal fake GAPIC with a dummy response.
firestore_api = mock.Mock(spec=['batch_get_documents'])
Expand Down

0 comments on commit d8a3ec3

Please sign in to comment.