Skip to content

refactor: add compile_aggregate #1904

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions bigframes/core/compile/sqlglot/aggregate_compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import functools
import typing

import sqlglot.expressions as sge

from bigframes.core import expression, window_spec
from bigframes.core.compile.sqlglot.expressions import typed_expr
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
import bigframes.operations as ops


def compile_aggregate(
    aggregate: expression.Aggregation,
    order_by: tuple[sge.Expression, ...],
) -> sge.Expression:
    """Translate a BigFrames aggregation into a SQLGlot expression.

    The aggregation is routed by arity: nullary aggregations are compiled from
    the op alone, unary ones from a single typed operand, and binary ones from
    two typed operands.

    Args:
        aggregate: The aggregation expression to compile.
        order_by: Ordering expressions, forwarded only to unary aggregations
            whose op is not order independent.

    Returns:
        The compiled SQLGlot expression.

    Raises:
        ValueError: If ``aggregate`` is not a nullary, unary or binary
            aggregation.
    """

    def _typed(operand) -> typed_expr.TypedExpr:
        # Pair the compiled scalar expression with the operand's output type.
        return typed_expr.TypedExpr(
            scalar_compiler.compile_scalar_expression(operand),
            operand.output_type,
        )

    if isinstance(aggregate, expression.NullaryAggregation):
        return compile_nullary_agg(aggregate.op)

    if isinstance(aggregate, expression.UnaryAggregation):
        operand = _typed(aggregate.arg)
        if aggregate.op.order_independent:
            return compile_unary_agg(aggregate.op, operand)
        return compile_ordered_unary_agg(aggregate.op, operand, order_by=order_by)

    if isinstance(aggregate, expression.BinaryAggregation):
        lhs = _typed(aggregate.left)
        rhs = _typed(aggregate.right)
        return compile_binary_agg(aggregate.op, lhs, rhs)

    raise ValueError(f"Unexpected aggregation: {aggregate}")


@functools.singledispatch
def compile_nullary_agg(
    op: ops.aggregations.WindowOp,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Fallback compiler for nullary aggregation ops.

    This is the single-dispatch default: it runs only when no implementation
    has been registered for the type of ``op``.

    Args:
        op: The aggregation operation to compile.
        window: Optional window specification (unused by the fallback).

    Raises:
        ValueError: Always, because ``op`` has no registered compiler.
    """
    message = f"Can't compile unrecognized operation: {op}"
    raise ValueError(message)


@functools.singledispatch
def compile_binary_agg(
    op: ops.aggregations.WindowOp,
    left: typed_expr.TypedExpr,
    right: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Fallback compiler for binary aggregation ops.

    This is the single-dispatch default: it runs only when no implementation
    has been registered for the type of ``op``.

    Args:
        op: The aggregation operation to compile.
        left: First typed operand (unused by the fallback).
        right: Second typed operand (unused by the fallback).
        window: Optional window specification (unused by the fallback).

    Raises:
        ValueError: Always, because ``op`` has no registered compiler.
    """
    message = f"Can't compile unrecognized operation: {op}"
    raise ValueError(message)


@functools.singledispatch
def compile_unary_agg(
    op: ops.aggregations.WindowOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Fallback compiler for order-independent unary aggregation ops.

    This is the single-dispatch default: it runs only when no implementation
    has been registered for the type of ``op``.

    Args:
        op: The aggregation operation to compile.
        column: The typed operand (unused by the fallback).
        window: Optional window specification (unused by the fallback).

    Raises:
        ValueError: Always, because ``op`` has no registered compiler.
    """
    message = f"Can't compile unrecognized operation: {op}"
    raise ValueError(message)


@functools.singledispatch
def compile_ordered_unary_agg(
    op: ops.aggregations.WindowOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
    # An immutable empty tuple avoids the shared-mutable-default pitfall that
    # a list literal would introduce for every registered implementation.
    order_by: typing.Sequence[sge.Expression] = (),
) -> sge.Expression:
    """Fallback compiler for order-dependent unary aggregation ops.

    This is the single-dispatch default: it runs only when no implementation
    has been registered for the type of ``op``.

    Args:
        op: The aggregation operation to compile.
        column: The typed operand (unused by the fallback).
        window: Optional window specification (unused by the fallback).
        order_by: Ordering expressions for the aggregation (unused by the
            fallback). Defaults to an empty sequence.

    Raises:
        ValueError: Always, because ``op`` has no registered compiler.
    """
    raise ValueError(f"Can't compile unrecognized operation: {op}")


@compile_unary_agg.register
def _(
    op: ops.aggregations.SumOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile ``SumOp`` as a ``SUM`` call that falls back to zero.

    Args:
        op: The sum operation (used only for dispatch).
        column: The typed operand to sum.
        window: Optional window specification, passed to
            ``_apply_window_if_present``.

    Returns:
        An ``IFNULL`` call wrapping the sum, with a zero literal of the
        column's dtype as the fallback value.
    """
    total = sge.func("SUM", column.expr)
    windowed_total = _apply_window_if_present(total, window)
    # SUM is NULL when every input is NULL, whereas pandas reports a zero sum,
    # so substitute a zero literal of the column's dtype.
    zero = ir._literal(0, column.dtype)
    return sge.func("IFNULL", windowed_total, zero)


def _apply_window_if_present(
    value: sge.Expression,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Return ``value`` unchanged when no window is requested.

    Args:
        value: The expression a window would be applied to.
        window: Optional window specification.

    Returns:
        ``value`` itself when ``window`` is ``None``.

    Raises:
        NotImplementedError: If ``window`` is not ``None``; windowed
            expressions are not supported here.
    """
    if window is None:
        return value
    raise NotImplementedError("Can't apply window to the expression.")
34 changes: 33 additions & 1 deletion bigframes/core/compile/sqlglot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from bigframes.core import expression, guid, identifiers, nodes, pyarrow_utils, rewrite
from bigframes.core.compile import configs
import bigframes.core.compile.sqlglot.aggregate_compiler as aggregate_compiler
from bigframes.core.compile.sqlglot.expressions import typed_expr
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
Expand Down Expand Up @@ -217,7 +218,7 @@ def compile_filter(
self, node: nodes.FilterNode, child: ir.SQLGlotIR
) -> ir.SQLGlotIR:
condition = scalar_compiler.compile_scalar_expression(node.predicate)
return child.filter(condition)
return child.filter(tuple([condition]))

@_compile_node.register
def compile_join(
Expand Down Expand Up @@ -267,6 +268,37 @@ def compile_random_sample(
) -> ir.SQLGlotIR:
return child.sample(node.fraction)

@_compile_node.register
def compile_aggregate(
    self, node: nodes.AggregateNode, child: ir.SQLGlotIR
) -> ir.SQLGlotIR:
    """Compile an ``AggregateNode`` on top of the already-compiled child.

    Args:
        node: The aggregate node supplying the ordering, the aggregations,
            the grouping keys and the ``dropna`` flag.
        child: The compiled IR of the node's child.

    Returns:
        The result of ``child.aggregate`` for the compiled aggregations,
        grouping keys and the keys that need a null filter.
    """
    order_exprs = tuple(
        sge.Ordered(
            this=scalar_compiler.compile_scalar_expression(spec.scalar_expression),
            desc=spec.direction.is_ascending is False,
            nulls_first=spec.na_last is False,
        )
        for spec in node.order_by
    )

    compiled_aggs: tuple[tuple[str, sge.Expression], ...] = tuple(
        (
            out_id.sql,
            aggregate_compiler.compile_aggregate(agg, order_by=order_exprs),
        )
        for agg, out_id in node.aggregations
    )

    group_keys: tuple[sge.Expression, ...] = tuple(
        scalar_compiler.compile_scalar_expression(key)
        for key in node.by_column_ids
    )

    # Only nullable grouping keys need a null filter, and only when the node
    # asks for null keys to be dropped.
    if node.dropna:
        nullable_keys = tuple(
            compiled
            for key, compiled in zip(node.by_column_ids, group_keys)
            if node.child.field_by_id[key.id].nullable
        )
    else:
        nullable_keys = ()

    return child.aggregate(compiled_aggs, group_keys, nullable_keys)


def _replace_unsupported_ops(node: nodes.BigFrameNode):
node = nodes.bottom_up(node, rewrite.rewrite_slice)
Expand Down
73 changes: 64 additions & 9 deletions bigframes/core/compile/sqlglot/sqlglot_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import dataclasses
import functools
import typing

from google.cloud import bigquery
Expand All @@ -25,11 +26,9 @@
import sqlglot.expressions as sge

from bigframes import dtypes
from bigframes.core import guid, utils
from bigframes.core import guid, local_data, schema, utils
from bigframes.core.compile.sqlglot.expressions import typed_expr
import bigframes.core.compile.sqlglot.sqlglot_types as sgt
import bigframes.core.local_data as local_data
import bigframes.core.schema as bf_schema

# shapely.wkt.dumps was moved to shapely.io.to_wkt in 2.0.
try:
Expand Down Expand Up @@ -68,7 +67,7 @@ def sql(self) -> str:
def from_pyarrow(
cls,
pa_table: pa.Table,
schema: bf_schema.ArraySchema,
schema: schema.ArraySchema,
uid_gen: guid.SequentialUIDGenerator,
) -> SQLGlotIR:
"""Builds SQLGlot expression from a pyarrow table.
Expand Down Expand Up @@ -280,9 +279,13 @@ def limit(

def filter(
self,
condition: sge.Expression,
conditions: tuple[sge.Expression, ...],
) -> SQLGlotIR:
"""Filters the query by adding a WHERE clause."""
condition = _and(conditions)
if condition is None:
return SQLGlotIR(expr=self.expr.copy(), uid_gen=self.uid_gen)

new_expr = _select_to_cte(
self.expr,
sge.to_identifier(
Expand Down Expand Up @@ -316,10 +319,11 @@ def join(
right_ctes = right_select.args.pop("with", [])
merged_ctes = [*left_ctes, *right_ctes]

join_conditions = [
_join_condition(left, right, joins_nulls) for left, right in conditions
]
join_on = sge.And(expressions=join_conditions) if join_conditions else None
join_on = _and(
tuple(
_join_condition(left, right, joins_nulls) for left, right in conditions
)
)

join_type_str = join_type if join_type != "outer" else "full outer"
new_expr = (
Expand Down Expand Up @@ -364,6 +368,47 @@ def sample(self, fraction: float) -> SQLGlotIR:
).where(condition, append=False)
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)

def aggregate(
    self,
    aggregations: tuple[tuple[str, sge.Expression], ...],
    by_cols: tuple[sge.Expression, ...],
    dropna_cols: tuple[sge.Expression, ...],
) -> SQLGlotIR:
    """Applies the aggregation expressions.

    Args:
        aggregations: (output_column_id, aggregation_expr) pairs; each
            expression is aliased to its output column id.
        by_cols: column expressions to group by; they are also selected
            ahead of the aggregations.
        dropna_cols: column expressions that must be non-null; rows where
            any of them is NULL are filtered out with a WHERE clause.
    """
    aliased_aggs = [
        sge.Alias(
            this=agg_expr,
            alias=sge.to_identifier(col_id, quoted=self.quoted),
        )
        for col_id, agg_expr in aggregations
    ]

    cte_name = sge.to_identifier(
        next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
    )
    grouped = _select_to_cte(self.expr, cte_name)
    grouped = grouped.group_by(*by_cols)
    grouped = grouped.select(*by_cols, *aliased_aggs, append=False)

    not_null_checks = tuple(
        sg.not_(sge.Is(this=key, expression=sge.Null())) for key in dropna_cols
    )
    keep_rows = _and(not_null_checks)
    if keep_rows is not None:
        grouped = grouped.where(keep_rows, append=False)
    return SQLGlotIR(expr=grouped, uid_gen=self.uid_gen)

def insert(
self,
destination: bigquery.TableReference,
Expand Down Expand Up @@ -552,6 +597,16 @@ def _table(table: bigquery.TableReference) -> sge.Table:
)


def _and(conditions: tuple[sge.Expression, ...]) -> typing.Optional[sge.Expression]:
    """Combine expressions into a single left-nested logical AND.

    Args:
        conditions: The expressions to combine.

    Returns:
        ``None`` when ``conditions`` is empty, the sole expression when there
        is exactly one, and otherwise nested ``sge.And`` nodes built left to
        right.
    """
    if not conditions:
        return None

    remaining = iter(conditions)
    combined = next(remaining)
    for condition in remaining:
        combined = sge.And(this=combined, expression=condition)
    return combined


def _join_condition(
left: typed_expr.TypedExpr,
right: typed_expr.TypedExpr,
Expand Down
46 changes: 46 additions & 0 deletions bigframes/core/rewrite/schema_binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import dataclasses
import typing

from bigframes.core import bigframe_node
from bigframes.core import expression as ex
Expand Down Expand Up @@ -65,4 +66,49 @@ def bind_schema_to_node(
conditions=conditions,
)

if isinstance(node, nodes.AggregateNode):
aggregations = []
for aggregation, id in node.aggregations:
if isinstance(aggregation, ex.UnaryAggregation):
replaced = typing.cast(
ex.Aggregation,
dataclasses.replace(
aggregation,
arg=typing.cast(
ex.RefOrConstant,
ex.bind_schema_fields(
aggregation.arg, node.child.field_by_id
),
),
),
)
aggregations.append((replaced, id))
elif isinstance(aggregation, ex.BinaryAggregation):
replaced = typing.cast(
ex.Aggregation,
dataclasses.replace(
aggregation,
left=typing.cast(
ex.RefOrConstant,
ex.bind_schema_fields(
aggregation.left, node.child.field_by_id
),
),
right=typing.cast(
ex.RefOrConstant,
ex.bind_schema_fields(
aggregation.right, node.child.field_by_id
),
),
),
)
aggregations.append((replaced, id))
else:
aggregations.append((aggregation, id))

return dataclasses.replace(
node,
aggregations=tuple(aggregations),
)

return node
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
WITH `bfcte_0` AS (
SELECT
`bool_col` AS `bfcol_0`,
`int64_too` AS `bfcol_1`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
`bfcol_1` AS `bfcol_2`,
`bfcol_0` AS `bfcol_3`
FROM `bfcte_0`
), `bfcte_2` AS (
SELECT
`bfcol_3`,
COALESCE(SUM(`bfcol_2`), 0) AS `bfcol_6`
FROM `bfcte_1`
WHERE
NOT `bfcol_3` IS NULL
GROUP BY
`bfcol_3`
)
SELECT
`bfcol_3` AS `bool_col`,
`bfcol_6` AS `int64_too`
FROM `bfcte_2`
ORDER BY
`bfcol_3` ASC NULLS LAST
Loading