Skip to content

Commit 1c45ccb

Browse files
feat: Support local execution of comparison ops (#1849)
1 parent 633bf98 commit 1c45ccb

File tree

4 files changed

+151
-3
lines changed

4 files changed

+151
-3
lines changed

bigframes/core/compile/polars/lowering.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,31 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import dataclasses
16+
1517
from bigframes import dtypes
1618
from bigframes.core import bigframe_node, expression
1719
from bigframes.core.rewrite import op_lowering
18-
from bigframes.operations import numeric_ops
20+
from bigframes.operations import comparison_ops, numeric_ops
1921
import bigframes.operations as ops
2022

2123
# TODO: Would be more precise to actually have separate op set for polars ops (where they diverge from the original ops)
2224

2325

26+
@dataclasses.dataclass
class CoerceArgsRule(op_lowering.OpLoweringRule):
    """Lowering rule that brings both operands of a binary op to one common type.

    One instance handles exactly one op class, given by ``op_type``. The op
    itself is kept; only its two arguments are (possibly) wrapped in casts.
    """

    # The binary op class this rule instance applies to.
    op_type: type[ops.BinaryOp]

    @property
    def op(self) -> type[ops.ScalarOp]:
        """The op class this rule lowers (same object as ``op_type``)."""
        return self.op_type

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Rebuild ``expr`` with its first two children coerced to a shared type."""
        assert isinstance(expr.op, self.op_type)
        left_input = expr.children[0]
        right_input = expr.children[1]
        coerced_left, coerced_right = _coerce_comparables(left_input, right_input)
        return expr.op.as_expr(coerced_left, coerced_right)
38+
39+
2440
class LowerFloorDivRule(op_lowering.OpLoweringRule):
2541
@property
2642
def op(self) -> type[ops.ScalarOp]:
@@ -40,7 +56,42 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
4056
return ops.where_op.as_expr(zero_result, divisor_is_zero, expr)
4157

4258

43-
POLARS_LOWERING_RULES = (LowerFloorDivRule(),)
59+
def _coerce_comparables(expr1: expression.Expression, expr2: expression.Expression):
    """Return ``(expr1, expr2)`` with each side cast to their common type.

    The common type comes from ``dtypes.coerce_to_common``. An operand that
    already has that type is returned untouched; otherwise it is wrapped in a
    cast built by ``_lower_cast``.
    """
    target_type = dtypes.coerce_to_common(expr1.output_type, expr2.output_type)

    def _to_target(operand: expression.Expression):
        # Skip the cast entirely when the operand already has the target type.
        if operand.output_type == target_type:
            return operand
        return _lower_cast(ops.AsTypeOp(target_type), operand)

    return _to_target(expr1), _to_target(expr2)
67+
68+
69+
# TODO: Need to handle bool->string cast to get capitalization correct
70+
def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
    """Apply ``cast_op`` to ``arg``, routing bool inputs through an int cast first.

    When ``arg`` is boolean and the target type is numeric, the value is first
    cast to ``dtypes.INT_DTYPE`` and then to the requested type. The original
    note cites bool -> decimal as the case that needs two steps; the guard here
    covers every numeric target.
    """
    bool_to_numeric = arg.output_type == dtypes.BOOL_DTYPE and dtypes.is_numeric(
        cast_op.to_type
    )
    source = (
        ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(arg) if bool_to_numeric else arg
    )
    return cast_op.as_expr(source)
76+
77+
78+
# One argument-coercing rule per comparison op class.
LOWER_COMPARISONS = tuple(
    map(
        CoerceArgsRule,
        (
            comparison_ops.EqOp,
            comparison_ops.EqNullsMatchOp,
            comparison_ops.NeOp,
            comparison_ops.LtOp,
            comparison_ops.GtOp,
            comparison_ops.LeOp,
            comparison_ops.GeOp,
        ),
    )
)

# Full rule set: the comparison rules first, then the floor-division rule.
POLARS_LOWERING_RULES = LOWER_COMPARISONS + (LowerFloorDivRule(),)
4495

4596

4697
def lower_ops_to_polars(root: bigframe_node.BigFrameNode) -> bigframe_node.BigFrameNode:

bigframes/core/compile/scalar_op_compiler.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1498,6 +1498,7 @@ def eq_op(
14981498
x: ibis_types.Value,
14991499
y: ibis_types.Value,
15001500
):
1501+
x, y = _coerce_comparables(x, y)
15011502
return x == y
15021503

15031504

@@ -1507,6 +1508,7 @@ def eq_nulls_match_op(
15071508
y: ibis_types.Value,
15081509
):
15091510
"""Variant of eq_op where nulls match each other. Only use where dtypes are known to be same."""
1511+
x, y = _coerce_comparables(x, y)
15101512
literal = ibis_types.literal("$NULL_SENTINEL$")
15111513
if hasattr(x, "fill_null"):
15121514
left = x.cast(ibis_dtypes.str).fill_null(literal)
@@ -1523,6 +1525,7 @@ def ne_op(
15231525
x: ibis_types.Value,
15241526
y: ibis_types.Value,
15251527
):
1528+
x, y = _coerce_comparables(x, y)
15261529
return x != y
15271530

15281531

@@ -1534,6 +1537,17 @@ def _null_or_value(value: ibis_types.Value, where_value: ibis_types.BooleanValue
15341537
)
15351538

15361539

1540+
def _coerce_comparables(
    x: ibis_types.Value,
    y: ibis_types.Value,
):
    """Cast the boolean side to int64 when exactly one operand is boolean.

    If both operands are boolean, or neither is, the pair is returned
    unchanged.
    """
    x_is_bool = x.type().is_boolean()
    y_is_bool = y.type().is_boolean()
    if x_is_bool == y_is_bool:
        # Same "boolean-ness" on both sides: nothing to reconcile.
        return x, y
    if x_is_bool:
        return x.cast(ibis_dtypes.int64), y
    return x, y.cast(ibis_dtypes.int64)
1549+
1550+
15371551
@scalar_op_compiler.register_binary_op(ops.and_op)
15381552
def and_op(
15391553
x: ibis_types.Value,
@@ -1735,6 +1749,7 @@ def lt_op(
17351749
x: ibis_types.Value,
17361750
y: ibis_types.Value,
17371751
):
1752+
x, y = _coerce_comparables(x, y)
17381753
return x < y
17391754

17401755

@@ -1744,6 +1759,7 @@ def le_op(
17441759
x: ibis_types.Value,
17451760
y: ibis_types.Value,
17461761
):
1762+
x, y = _coerce_comparables(x, y)
17471763
return x <= y
17481764

17491765

@@ -1753,6 +1769,7 @@ def gt_op(
17531769
x: ibis_types.Value,
17541770
y: ibis_types.Value,
17551771
):
1772+
x, y = _coerce_comparables(x, y)
17561773
return x > y
17571774

17581775

@@ -1762,6 +1779,7 @@ def ge_op(
17621779
x: ibis_types.Value,
17631780
y: ibis_types.Value,
17641781
):
1782+
x, y = _coerce_comparables(x, y)
17651783
return x >= y
17661784

17671785

bigframes/session/polars_executor.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,20 @@
3232
nodes.OrderByNode,
3333
nodes.ReversedNode,
3434
nodes.SelectionNode,
35+
nodes.ProjectionNode,
3536
nodes.SliceNode,
3637
nodes.AggregateNode,
3738
)
3839

39-
_COMPATIBLE_SCALAR_OPS = ()
40+
# Comparison ops listed as compatible scalar ops for this executor.
# NOTE(review): the consumer of this tuple is not visible in this excerpt.
_COMPATIBLE_SCALAR_OPS = (
    bigframes.operations.eq_op,
    bigframes.operations.eq_null_match_op,
    bigframes.operations.ne_op,
    bigframes.operations.gt_op,
    bigframes.operations.lt_op,
    bigframes.operations.ge_op,
    bigframes.operations.le_op,
)
4049
_COMPATIBLE_AGG_OPS = (agg_ops.SizeOp, agg_ops.SizeUnaryOp)
4150

4251

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import itertools
16+
17+
import pytest
18+
19+
from bigframes.core import array_value
20+
import bigframes.operations as ops
21+
from bigframes.session import polars_executor
22+
from bigframes.testing.engine_utils import assert_equivalence_execution
23+
24+
pytest.importorskip("polars")
25+
26+
# Polars is used as the reference because it's fast and local. Generally though, prefer the gbq engine where they disagree.
27+
REFERENCE_ENGINE = polars_executor.PolarsExecutor()
28+
29+
# numeric domain
30+
31+
32+
def apply_op_pairwise(
    array: array_value.ArrayValue, op: ops.BinaryOp, excluded_cols=()
) -> array_value.ArrayValue:
    """Apply ``op`` to every ordered pair of distinct columns that it accepts.

    Args:
        array: Array value whose columns are paired up.
        op: Binary op to apply to each ``(left, right)`` column pair.
        excluded_cols: Column ids to leave out on either side. Any container
            supporting ``in`` works; the default is an immutable empty tuple
            rather than a shared mutable list.

    Returns:
        The array value produced by ``array.compute_values`` for the collected
        expressions.

    Raises:
        AssertionError: If no column pair is accepted by ``op``.
    """
    exprs = []
    for l_arg, r_arg in itertools.permutations(array.column_ids, 2):
        if (l_arg in excluded_cols) or (r_arg in excluded_cols):
            continue
        try:
            # Probe the op's type rules; a TypeError means this pair of column
            # types is unsupported, so the pair is skipped.
            op.output_type(array.get_column_type(l_arg), array.get_column_type(r_arg))
            exprs.append(op.as_expr(l_arg, r_arg))
        except TypeError:
            continue
    assert len(exprs) > 0, "op accepted no pair of columns"
    new_arr, _ = array.compute_values(exprs)
    return new_arr
49+
50+
51+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize(
    "op",
    [
        ops.eq_op,
        ops.eq_null_match_op,
        ops.ne_op,
        ops.gt_op,
        ops.lt_op,
        ops.le_op,
        ops.ge_op,
    ],
)
def test_engines_project_comparison_op(
    scalars_array_value: array_value.ArrayValue, engine, op
):
    """Each comparison op, applied pairwise, must match the reference engine."""
    # The string column is excluded because it does not contain dates.
    # Note from the original author: the bool column doesn't actually work
    # properly for the bq engine.
    projected = apply_op_pairwise(
        scalars_array_value, op, excluded_cols=["string_col"]
    )
    assert_equivalence_execution(projected.node, REFERENCE_ENGINE, engine)

0 commit comments

Comments
 (0)