Skip to content

Commit 84e344e

Browse files
committed
fixed tests
1 parent 8147f72 commit 84e344e

File tree

2 files changed

+28
-14
lines changed

2 files changed

+28
-14
lines changed

tests/unit/v1/test_base_query.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2120,7 +2120,9 @@ def test__query_pipeline_cursors():
21202120
from google.cloud.firestore_v1 import pipeline_expressions as expr
21212121

21222122
client = make_client()
2123-
query_start = client.collection("my_col").order_by("field_a").start_at({"field_a": 10})
2123+
query_start = (
2124+
client.collection("my_col").order_by("field_a").start_at({"field_a": 10})
2125+
)
21242126
pipeline = query_start._build_pipeline(client.pipeline())
21252127

21262128
# Stages:
@@ -2129,7 +2131,7 @@ def test__query_pipeline_cursors():
21292131
# 2: Where (cursor condition)
21302132
# 3: Sort (field_a)
21312133
assert len(pipeline.stages) == 4
2132-
2134+
21332135
where_stage = pipeline.stages[2]
21342136
assert isinstance(where_stage, stages.Where)
21352137
# Expected: (field_a > 10) OR (field_a == 10)
@@ -2336,6 +2338,7 @@ def _make_snapshot(docref, values):
23362338

23372339
return document.DocumentSnapshot(docref, values, True, None, None, None)
23382340

2341+
23392342
def test__where_conditions_from_cursor_descending():
23402343
from google.cloud.firestore_v1.base_query import _where_conditions_from_cursor
23412344
from google.cloud.firestore_v1 import pipeline_expressions
@@ -2348,10 +2351,7 @@ def test__where_conditions_from_cursor_descending():
23482351
cursor = ([10], True)
23492352
condition = _where_conditions_from_cursor(cursor, [ordering], is_start_cursor=True)
23502353
# Expected: field < 10 OR field == 10
2351-
expected = pipeline_expressions.Or(
2352-
field_expr.less_than(10),
2353-
field_expr.equal(10)
2354-
)
2354+
expected = pipeline_expressions.Or(field_expr.less_than(10), field_expr.equal(10))
23552355
assert condition == expected
23562356

23572357
# Case 2: StartAfter (exclusive) -> < 10
@@ -2366,8 +2366,7 @@ def test__where_conditions_from_cursor_descending():
23662366
condition = _where_conditions_from_cursor(cursor, [ordering], is_start_cursor=False)
23672367
# Expected: field > 10 OR field == 10
23682368
expected = pipeline_expressions.Or(
2369-
field_expr.greater_than(10),
2370-
field_expr.equal(10)
2369+
field_expr.greater_than(10), field_expr.equal(10)
23712370
)
23722371
assert condition == expected
23732372

tests/unit/v1/test_pipeline_expressions.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -463,58 +463,72 @@ def test__from_query_filter_pb_unary_filter_unknown_op(self, mock_client):
463463
BooleanExpression._from_query_filter_pb(wrapped_filter_pb, mock_client)
464464

465465
@pytest.mark.parametrize(
466-
"op_enum, value, expected_expr_func",
466+
"op_enum, value, expected_expr_func, expects_existance",
467467
[
468468
(
469469
query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN,
470470
10,
471471
Expression.less_than,
472+
True,
472473
),
473474
(
474475
query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN_OR_EQUAL,
475476
10,
476477
Expression.less_than_or_equal,
478+
True,
477479
),
478480
(
479481
query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN,
480482
10,
481483
Expression.greater_than,
484+
True,
482485
),
483486
(
484487
query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN_OR_EQUAL,
485488
10,
486489
Expression.greater_than_or_equal,
490+
True,
491+
),
492+
(
493+
query_pb.StructuredQuery.FieldFilter.Operator.EQUAL,
494+
10,
495+
Expression.equal,
496+
True,
487497
),
488-
(query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, 10, Expression.equal),
489498
(
490499
query_pb.StructuredQuery.FieldFilter.Operator.NOT_EQUAL,
491500
10,
492501
Expression.not_equal,
502+
False,
493503
),
494504
(
495505
query_pb.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS,
496506
10,
497507
Expression.array_contains,
508+
True,
498509
),
499510
(
500511
query_pb.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS_ANY,
501512
[10, 20],
502513
Expression.array_contains_any,
514+
True,
503515
),
504516
(
505517
query_pb.StructuredQuery.FieldFilter.Operator.IN,
506518
[10, 20],
507519
Expression.equal_any,
520+
True,
508521
),
509522
(
510523
query_pb.StructuredQuery.FieldFilter.Operator.NOT_IN,
511524
[10, 20],
512525
Expression.not_equal_any,
526+
False,
513527
),
514528
],
515529
)
516530
def test__from_query_filter_pb_field_filter(
517-
self, mock_client, op_enum, value, expected_expr_func
531+
self, mock_client, op_enum, value, expected_expr_func, expects_existance
518532
):
519533
"""
520534
test supported field filters
@@ -536,10 +550,11 @@ def test__from_query_filter_pb_field_filter(
536550
[Constant(e) for e in value] if isinstance(value, list) else Constant(value)
537551
)
538552
expected_condition = expected_expr_func(field_expr, value)
539-
# should include existance checks
540-
expected = expr.And(field_expr.exists(), expected_condition)
553+
if expects_existance:
554+
# some expressions include extra existance checks
555+
expected_condition = expr.And(field_expr.exists(), expected_condition)
541556

542-
assert repr(result) == repr(expected)
557+
assert repr(result) == repr(expected_condition)
543558

544559
def test__from_query_filter_pb_field_filter_unknown_op(self, mock_client):
545560
"""

0 commit comments

Comments
 (0)