Skip to content

Commit

Permalink
feat: Allow queries using server side IN. (#954)
Browse files Browse the repository at this point in the history
* feat: Allow queries using server side IN.

* Rename force_server to server_op.
  • Loading branch information
sorced-jim authored Feb 28, 2024
1 parent 106772f commit 2646cef
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 25 deletions.
2 changes: 1 addition & 1 deletion google/cloud/ndb/_datastore_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
">": query_pb2.PropertyFilter.Operator.GREATER_THAN,
">=": query_pb2.PropertyFilter.Operator.GREATER_THAN_OR_EQUAL,
"!=": query_pb2.PropertyFilter.Operator.NOT_EQUAL,
"IN": query_pb2.PropertyFilter.Operator.IN,
"in": query_pb2.PropertyFilter.Operator.IN,
}

_KEY_NOT_IN_CACHE = object()
Expand Down
4 changes: 2 additions & 2 deletions google/cloud/ndb/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1258,7 +1258,7 @@ def __ge__(self, value):
"""FilterNode: Represents the ``>=`` comparison."""
return self._comparison(">=", value)

def _IN(self, value):
def _IN(self, value, server_op=False):
"""For the ``in`` comparison operator.
The ``in`` operator cannot be overloaded in the way we want
Expand Down Expand Up @@ -1315,7 +1315,7 @@ def _IN(self, value):
sub_value = self._datastore_type(sub_value)
values.append(sub_value)

return query.FilterNode(self._name, "in", values)
return query.FilterNode(self._name, "in", values, server_op=server_op)

IN = _IN
"""Used to check if a property value is contained in a set of values.
Expand Down
18 changes: 4 additions & 14 deletions google/cloud/ndb/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,7 @@ class FilterNode(Node):
opsymbol (str): The comparison operator. One of ``=``, ``!=``, ``<``,
``<=``, ``>``, ``>=`` or ``in``.
value (Any): The value to filter on / relative to.
server_op (bool): Force the operator to use a server side filter.
Raises:
TypeError: If ``opsymbol`` is ``"in"`` but ``value`` is not a
Expand All @@ -630,7 +631,7 @@ class FilterNode(Node):
_opsymbol = None
_value = None

def __new__(cls, name, opsymbol, value):
def __new__(cls, name, opsymbol, value, server_op=False):
# Avoid circular import in Python 2.7
from google.cloud.ndb import model

Expand All @@ -648,7 +649,8 @@ def __new__(cls, name, opsymbol, value):
return FalseNode()
if len(nodes) == 1:
return nodes[0]
return DisjunctionNode(*nodes)
if not server_op:
return DisjunctionNode(*nodes)

instance = super(FilterNode, cls).__new__(cls)
instance._name = name
Expand Down Expand Up @@ -695,24 +697,12 @@ def _to_filter(self, post=False):
Optional[query_pb2.PropertyFilter]: Returns :data:`None`, if
this is a post-filter, otherwise returns the protocol buffer
representation of the filter.
Raises:
NotImplementedError: If the ``opsymbol`` is ``in``, since
they should correspond to a composite filter. This should
never occur since the constructor will create ``OR`` nodes for
``in``
"""
# Avoid circular import in Python 2.7
from google.cloud.ndb import _datastore_query

if post:
return None
if self._opsymbol in (_IN_OP):
raise NotImplementedError(
"Inequality filters are not single filter "
"expressions and therefore cannot be converted "
"to a single filter ({!r})".format(self._opsymbol)
)

return _datastore_query.make_filter(self._name, self._opsymbol, self._value)

Expand Down
34 changes: 34 additions & 0 deletions tests/system/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,40 @@ def make_entities():
assert not more


@pytest.mark.usefixtures("client_context")
def test_fetch_page_in_query(dispose_of):
page_size = 5
n_entities = page_size * 2

class SomeKind(ndb.Model):
foo = ndb.IntegerProperty()

@ndb.toplevel
def make_entities():
entities = [SomeKind(foo=n_entities) for i in range(n_entities)]
keys = yield [entity.put_async() for entity in entities]
raise ndb.Return(keys)

for key in make_entities():
dispose_of(key._key)

query = SomeKind.query().filter(SomeKind.foo.IN([1, 2, n_entities], server_op=True))
eventually(query.fetch, length_equals(n_entities))

results, cursor, more = query.fetch_page(page_size)
assert len(results) == page_size
assert more

safe_cursor = cursor.urlsafe()
next_cursor = ndb.Cursor(urlsafe=safe_cursor)
results, cursor, more = query.fetch_page(page_size, start_cursor=next_cursor)
assert len(results) == page_size

results, cursor, more = query.fetch_page(page_size, start_cursor=cursor)
assert not results
assert not more


@pytest.mark.usefixtures("client_context")
def test_polymodel_query(ds_entity):
class Animal(ndb.PolyModel):
Expand Down
29 changes: 28 additions & 1 deletion tests/unit/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ def test__IN_wrong_container():
assert model.Property._FIND_METHODS_CACHE == {}

@staticmethod
def test__IN():
def test__IN_default():
prop = model.Property("name", indexed=True)
or_node = prop._IN(["a", None, "xy"])
expected = query_module.DisjunctionNode(
Expand All @@ -561,6 +561,33 @@ def test__IN():
# Also verify the alias
assert or_node == prop.IN(["a", None, "xy"])

@staticmethod
def test__IN_client():
prop = model.Property("name", indexed=True)
or_node = prop._IN(["a", None, "xy"], server_op=False)
expected = query_module.DisjunctionNode(
query_module.FilterNode("name", "=", "a"),
query_module.FilterNode("name", "=", None),
query_module.FilterNode("name", "=", "xy"),
)
assert or_node == expected
# Also verify the alias
assert or_node == prop.IN(["a", None, "xy"])

@staticmethod
def test_server__IN():
prop = model.Property("name", indexed=True)
in_node = prop._IN(["a", None, "xy"], server_op=True)
assert in_node == prop.IN(["a", None, "xy"], server_op=True)
assert in_node != query_module.DisjunctionNode(
query_module.FilterNode("name", "=", "a"),
query_module.FilterNode("name", "=", None),
query_module.FilterNode("name", "=", "xy"),
)
assert in_node == query_module.FilterNode(
"name", "in", ["a", None, "xy"], server_op=True
)

@staticmethod
def test___neg__():
prop = model.Property("name")
Expand Down
7 changes: 0 additions & 7 deletions tests/unit/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,13 +701,6 @@ def test__to_ne_filter_op():
filter_node = query_module.FilterNode("speed", "!=", 88)
assert filter_node._to_filter(post=True) is None

@staticmethod
def test__to_filter_bad_op():
filter_node = query_module.FilterNode("speed", ">=", 88)
filter_node._opsymbol = "in"
with pytest.raises(NotImplementedError):
filter_node._to_filter()

@staticmethod
@mock.patch("google.cloud.ndb._datastore_query")
def test__to_filter(_datastore_query):
Expand Down

0 comments on commit 2646cef

Please sign in to comment.