Skip to content

refactor: add uid generator and encapsulate query as CTE in SQLGlotCompiler #1679

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ repos:
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
exclude: "^tests/unit/core/compile/sqlglot/snapshots"
- id: check-yaml
- repo: https://github.com/pycqa/isort
rev: 5.12.0
Expand Down
111 changes: 57 additions & 54 deletions bigframes/core/compile/sqlglot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,28 @@

import dataclasses
import functools
import itertools
import typing

from google.cloud import bigquery
import pyarrow as pa
import sqlglot.expressions as sge

from bigframes.core import expression, identifiers, nodes, rewrite
from bigframes.core import expression, guid, identifiers, nodes, rewrite
from bigframes.core.compile import configs
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
import bigframes.core.ordering as bf_ordering


@dataclasses.dataclass(frozen=True)
class SQLGlotCompiler:
"""Compiles BigFrame nodes into SQL using SQLGlot."""

uid_gen: guid.SequentialUIDGenerator
"""Generator for unique identifiers."""

def __init__(self):
    """Initializes the compiler with a fresh per-instance UID generator.

    NOTE: the enclosing dataclass is declared with ``frozen=True`` (see the
    decorator above), so a plain ``self.uid_gen = ...`` assignment would raise
    ``dataclasses.FrozenInstanceError``. We bypass the frozen ``__setattr__``
    with ``object.__setattr__``, mirroring what generated dataclass
    ``__init__`` methods do for frozen classes.
    """
    object.__setattr__(self, "uid_gen", guid.SequentialUIDGenerator())

def compile(
self,
node: nodes.BigFrameNode,
Expand Down Expand Up @@ -82,7 +86,7 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
result_node = typing.cast(
nodes.ResultNode, rewrite.column_pruning(result_node)
)
result_node = _remap_variables(result_node)
result_node = self._remap_variables(result_node)
sql = self._compile_result_node(result_node)
return configs.CompileResult(
sql, result_node.schema.to_bigquery(), result_node.order_by
Expand All @@ -92,7 +96,7 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
result_node = dataclasses.replace(result_node, order_by=None)
result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))

result_node = _remap_variables(result_node)
result_node = self._remap_variables(result_node)
sql = self._compile_result_node(result_node)
# Return the ordering iff no extra columns are needed to define the row order
if ordering is not None:
Expand All @@ -106,63 +110,62 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
sql, result_node.schema.to_bigquery(), output_order
)

def _remap_variables(self, node: nodes.ResultNode) -> nodes.ResultNode:
"""Remaps `ColumnId`s in the BFET of a `ResultNode` to produce deterministic UIDs.

Every column identifier in the tree is replaced with a sequential
`bfcol_<n>` identifier drawn from `self.uid_gen`, so repeated
compilations of the same tree yield identical SQL.
"""

# The rewrite returns (new_node, old->new id mapping); the mapping is
# not needed here and is discarded.
result_node, _ = rewrite.remap_variables(
node, map(identifiers.ColumnId, self.uid_gen.get_uid_stream("bfcol_"))
)
return typing.cast(nodes.ResultNode, result_node)

def _compile_result_node(self, root: nodes.ResultNode) -> str:
sqlglot_ir = compile_node(root.child)
sqlglot_ir = self.compile_node(root.child)
# TODO: add order_by, limit, and selections to sqlglot_expr
return sqlglot_ir.sql

@functools.lru_cache(maxsize=5000)
def compile_node(self, node: nodes.BigFrameNode) -> ir.SQLGlotIR:
"""Compiles node into a `SQLGlotIR`, reducing bottom-up. Caches result.

NOTE(review): `lru_cache` on an instance method keys on `self` and keeps
the compiler instance alive for the lifetime of the cache (ruff B019);
confirm the intended compiler lifetime, or move to a per-instance cache.
"""
return node.reduce_up(
lambda node, children: self._compile_node(node, *children)
)

def _replace_unsupported_ops(node: nodes.BigFrameNode):
node = nodes.bottom_up(node, rewrite.rewrite_slice)
node = nodes.bottom_up(node, rewrite.rewrite_timedelta_expressions)
node = nodes.bottom_up(node, rewrite.rewrite_range_rolling)
return node


def _remap_variables(node: nodes.ResultNode) -> nodes.ResultNode:
"""Remaps `ColumnId`s in the BFET of a `ResultNode` to produce deterministic UIDs."""

def anonymous_column_ids() -> typing.Generator[identifiers.ColumnId, None, None]:
for i in itertools.count():
yield identifiers.ColumnId(name=f"bfcol_{i}")

result_node, _ = rewrite.remap_variables(node, anonymous_column_ids())
return typing.cast(nodes.ResultNode, result_node)


@functools.lru_cache(maxsize=5000)
def compile_node(node: nodes.BigFrameNode) -> ir.SQLGlotIR:
"""Compiles node into CompileArrayValue. Caches result."""
return node.reduce_up(lambda node, children: _compile_node(node, *children))


@functools.singledispatch
def _compile_node(
node: nodes.BigFrameNode, *compiled_children: ir.SQLGlotIR
) -> ir.SQLGlotIR:
"""Defines transformation but isn't cached, always use compile_node instead"""
raise ValueError(f"Can't compile unrecognized node: {node}")
@functools.singledispatchmethod
def _compile_node(
self, node: nodes.BigFrameNode, *compiled_children: ir.SQLGlotIR
) -> ir.SQLGlotIR:
"""Defines transformation but isn't cached, always use compile_node instead.

This is the single-dispatch fallback: it is reached only when no
`@_compile_node.register` overload matches the node's type, and it fails
loudly rather than producing incorrect SQL.

Raises:
    ValueError: always, for unrecognized node types.
"""
raise ValueError(f"Can't compile unrecognized node: {node}")

@_compile_node.register
def compile_readlocal(self, node: nodes.ReadLocalNode, *args) -> ir.SQLGlotIR:
pa_table = node.local_data_source.data
pa_table = pa_table.select([item.source_id for item in node.scan_list.items])
pa_table = pa_table.rename_columns(
[item.id.sql for item in node.scan_list.items]
)

offsets = node.offsets_col.sql if node.offsets_col else None
if offsets:
pa_table = pa_table.append_column(
offsets, pa.array(range(pa_table.num_rows), type=pa.int64())
)

@_compile_node.register
def compile_readlocal(node: nodes.ReadLocalNode, *args) -> ir.SQLGlotIR:
pa_table = node.local_data_source.data
pa_table = pa_table.select([item.source_id for item in node.scan_list.items])
pa_table = pa_table.rename_columns([item.id.sql for item in node.scan_list.items])
return ir.SQLGlotIR.from_pyarrow(pa_table, node.schema, uid_gen=self.uid_gen)

offsets = node.offsets_col.sql if node.offsets_col else None
if offsets:
pa_table = pa_table.append_column(
offsets, pa.array(range(pa_table.num_rows), type=pa.int64())
@_compile_node.register
def compile_selection(
self, node: nodes.SelectionNode, child: ir.SQLGlotIR
) -> ir.SQLGlotIR:
"""Compiles a `SelectionNode` into a projection over the compiled child.

Each (expression, output id) pair becomes one aliased column in the
SELECT list of the child's SQL.
"""
selected_cols: tuple[tuple[str, sge.Expression], ...] = tuple(
(id.sql, scalar_compiler.compile_scalar_expression(expr))
for expr, id in node.input_output_pairs
)
return child.select(selected_cols)

return ir.SQLGlotIR.from_pyarrow(pa_table, node.schema)


@_compile_node.register
def compile_selection(node: nodes.SelectionNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
select_cols: typing.Dict[str, sge.Expression] = {
id.name: scalar_compiler.compile_scalar_expression(expr)
for expr, id in node.input_output_pairs
}
return child.select(select_cols)
def _replace_unsupported_ops(node: nodes.BigFrameNode):
"""Rewrites node types this compiler does not handle into supported equivalents."""
# Each rewrite pass walks the whole tree bottom-up; order of passes is
# preserved from the original implementation.
node = nodes.bottom_up(node, rewrite.rewrite_slice)
node = nodes.bottom_up(node, rewrite.rewrite_timedelta_expressions)
node = nodes.bottom_up(node, rewrite.rewrite_range_rolling)
return node
44 changes: 37 additions & 7 deletions bigframes/core/compile/sqlglot/sqlglot_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import sqlglot.expressions as sge

from bigframes import dtypes
from bigframes.core import guid
import bigframes.core.compile.sqlglot.sqlglot_types as sgt
import bigframes.core.local_data as local_data
import bigframes.core.schema as schemata
Expand Down Expand Up @@ -52,14 +53,20 @@ class SQLGlotIR:
pretty: bool = True
"""Whether to pretty-print the generated SQL."""

uid_gen: guid.SequentialUIDGenerator = guid.SequentialUIDGenerator()
"""Generator for unique identifiers."""

@property
def sql(self) -> str:
"""Generate SQL string from the given expression."""
return self.expr.sql(dialect=self.dialect, pretty=self.pretty)

@classmethod
def from_pyarrow(
cls, pa_table: pa.Table, schema: schemata.ArraySchema
cls,
pa_table: pa.Table,
schema: schemata.ArraySchema,
uid_gen: guid.SequentialUIDGenerator,
) -> SQLGlotIR:
"""Builds SQLGlot expression from pyarrow table."""
dtype_expr = sge.DataType(
Expand Down Expand Up @@ -95,21 +102,44 @@ def from_pyarrow(
),
],
)
return cls(expr=sg.select(sge.Star()).from_(expr))
return cls(expr=sg.select(sge.Star()).from_(expr), uid_gen=uid_gen)

def select(
self,
select_cols: typing.Dict[str, sge.Expression],
selected_cols: tuple[tuple[str, sge.Expression], ...],
) -> SQLGlotIR:
selected_cols = [
cols_expr = [
sge.Alias(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we always want to alias, even if the ids don't change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the IDs remain unchanged, we could omit the alias (e.g., `col1 AS col1`) to shorten the SQL, although the redundant alias is grammatically valid. However, our golden tests never exercise that case; they only use renaming aliases like `col0 AS col9`, where the expression is ColumnDef(col0) and the ID is ColumnId(col9). Therefore, to optimize SQL length, we would first need to address nodes.SelectionNode. I can create a bug ticket to revisit this SQL-length optimization later.

this=expr,
alias=sge.to_identifier(id, quoted=self.quoted),
)
for id, expr in select_cols.items()
for id, expr in selected_cols
]
expr = self.expr.select(*selected_cols, append=False)
return SQLGlotIR(expr=expr)
new_expr = self._encapsulate_as_cte().select(*cols_expr, append=False)
return SQLGlotIR(expr=new_expr)

def _encapsulate_as_cte(
    self,
) -> sge.Select:
    """Transforms a given sge.Select query by pushing its main SELECT statement
    into a new CTE and then generates a 'SELECT * FROM new_cte_name'
    for the new query.

    Returns:
        A new ``sge.Select`` of the form ``WITH ... SELECT * FROM bfcte_<n>``,
        with any CTEs already attached to ``self.expr`` carried over.
    """
    select_expr = self.expr.copy()

    # BUGFIX: sqlglot stores the WITH clause as a single `sge.With` node
    # under args["with"], not as a list of CTEs. The previous
    # `args.pop("with", [])` returned that node, so concatenating it with
    # `[new_cte]` raised a TypeError once a prior call had installed a CTE.
    # Unwrap its `.expressions` list instead so CTEs accumulate correctly.
    existing_with = select_expr.args.pop("with", None)
    existing_ctes = existing_with.expressions if existing_with else []

    new_cte_name = sge.to_identifier(
        next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
    )
    new_cte = sge.CTE(
        this=select_expr,
        alias=new_cte_name,
    )
    new_with_clause = sge.With(expressions=existing_ctes + [new_cte])
    new_select_expr = (
        sge.Select().select(sge.Star()).from_(sge.Table(this=new_cte_name))
    )
    new_select_expr.set("with", new_with_clause)
    return new_select_expr


def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
Expand Down
22 changes: 22 additions & 0 deletions bigframes/core/guid.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import typing

# Process-wide counter shared by every generate_guid() call.
_GUID_COUNTER = 0


def generate_guid(prefix="col_"):
    """Return a process-unique identifier of the form ``bfuid_<prefix><n>``.

    Relies on a module-global counter, so it is not thread-safe.
    """
    global _GUID_COUNTER
    _GUID_COUNTER = _GUID_COUNTER + 1
    return "bfuid_{}{}".format(prefix, _GUID_COUNTER)


class SequentialUIDGenerator:
    """Yields sequential UIDs per prefix, e.g. "t0", "t1", "c0", "t2", ...

    Each prefix owns a monotonically increasing counter, and all streams
    obtained for the same prefix share that counter. Not thread-safe.
    """

    def __init__(self):
        # Maps each prefix to the next integer suffix to hand out.
        self.prefix_counters: typing.Dict[str, int] = {}

    def get_uid_stream(self, prefix: str) -> typing.Generator[str, None, None]:
        """Yields a continuous stream of raw UID strings for the given prefix."""
        self.prefix_counters.setdefault(prefix, 0)
        while True:
            next_id = self.prefix_counters[prefix]
            self.prefix_counters[prefix] = next_id + 1
            yield f"{prefix}{next_id}"
18 changes: 8 additions & 10 deletions bigframes/core/rewrite/identifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,20 @@
# limitations under the License.
from __future__ import annotations

from typing import Generator, Tuple
import typing

import bigframes.core.identifiers
import bigframes.core.nodes
from bigframes.core import identifiers, nodes


# TODO: May as well just outright remove selection nodes in this process.
def remap_variables(
root: bigframes.core.nodes.BigFrameNode,
id_generator: Generator[bigframes.core.identifiers.ColumnId, None, None],
) -> Tuple[
bigframes.core.nodes.BigFrameNode,
dict[bigframes.core.identifiers.ColumnId, bigframes.core.identifiers.ColumnId],
root: nodes.BigFrameNode,
id_generator: typing.Iterator[identifiers.ColumnId],
) -> typing.Tuple[
nodes.BigFrameNode,
dict[identifiers.ColumnId, identifiers.ColumnId],
]:
"""
Remap all variables in the BFET using the id_generator.
"""Remaps `ColumnId`s in the BFET to produce deterministic and sequential UIDs.

Note: this will convert a DAG to a tree.
"""
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/core/compile/sqlglot/compiler_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
class SQLCompilerExecutor(bigframes.session.executor.Executor):
"""Executor for SQL compilation using sqlglot."""

compiler = sqlglot.SQLGlotCompiler()
compiler = sqlglot

def to_sql(
self,
Expand All @@ -41,7 +41,9 @@ def to_sql(

# Compared with BigQueryCachingExecutor, SQLCompilerExecutor skips
# caching the subtree.
return self.compiler.compile(array_value.node, ordered=ordered)
return self.compiler.SQLGlotCompiler().compile(
array_value.node, ordered=ordered
)


class SQLCompilerSession(bigframes.session.Session):
Expand Down
Loading