feat: Short circuit query for local scan #1618

Merged: 15 commits, Apr 23, 2025
103 changes: 65 additions & 38 deletions bigframes/core/local_data.py
@@ -31,7 +31,6 @@
import pyarrow.parquet # type: ignore

import bigframes.core.schema as schemata
import bigframes.core.utils as utils
import bigframes.dtypes


@@ -79,7 +78,7 @@ def from_pandas(cls, dataframe: pd.DataFrame) -> ManagedArrowTable:
mat = ManagedArrowTable(
pa.table(columns, names=column_names), schemata.ArraySchema(tuple(fields))
)
mat.validate(include_content=True)
mat.validate()
return mat

@classmethod
@@ -98,15 +97,14 @@ def from_pyarrow(self, table: pa.Table) -> ManagedArrowTable:
mat.validate()
return mat

def to_parquet(
def to_pyarrow_table(
self,
dst: Union[str, io.IOBase],
*,
offsets_col: Optional[str] = None,
geo_format: Literal["wkb", "wkt"] = "wkt",
duration_type: Literal["int", "duration"] = "duration",
json_type: Literal["string"] = "string",
):
) -> pa.Table:
pa_table = self.data
if offsets_col is not None:
pa_table = pa_table.append_column(
@@ -119,6 +117,23 @@ def to_parquet(
f"duration as {duration_type} not yet implemented"
)
assert json_type == "string"
return pa_table

def to_parquet(
self,
dst: Union[str, io.IOBase],
*,
offsets_col: Optional[str] = None,
geo_format: Literal["wkb", "wkt"] = "wkt",
duration_type: Literal["int", "duration"] = "duration",
json_type: Literal["string"] = "string",
):
pa_table = self.to_pyarrow_table(
offsets_col=offsets_col,
geo_format=geo_format,
duration_type=duration_type,
json_type=json_type,
)
pyarrow.parquet.write_table(pa_table, where=dst)

def itertuples(
@@ -142,7 +157,7 @@ def itertuples(
):
yield tuple(row_dict.values())

def validate(self, include_content: bool = False):
def validate(self):
for bf_field, arrow_field in zip(self.schema.items, self.data.schema):
expected_arrow_type = _get_managed_storage_type(bf_field.dtype)
arrow_type = arrow_field.type
@@ -151,38 +166,6 @@ def validate(self):
f"Field {bf_field} has arrow array type: {arrow_type}, expected type: {expected_arrow_type}"
)

if include_content:
for batch in self.data.to_batches():
for field in self.schema.items:
_validate_content(batch.column(field.column), field.dtype)


def _validate_content(array: pa.Array, dtype: bigframes.dtypes.Dtype):
"""
Recursively validates the content of a PyArrow Array based on the
expected BigFrames dtype, focusing on complex types like JSON, structs,
and arrays where the Arrow type alone isn't sufficient.
"""
# TODO: validate GEO data context.
if dtype == bigframes.dtypes.JSON_DTYPE:
values = array.to_pandas()
for data in values:
# Skip scalar null values to avoid `TypeError` from json.load.
if not utils.is_list_like(data) and pd.isna(data):
continue
try:
# Attempts JSON parsing.
json.loads(data)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON format found: {data!r}") from e
elif bigframes.dtypes.is_struct_like(dtype):
for field_name, dtype in bigframes.dtypes.get_struct_fields(dtype).items():
_validate_content(array.field(field_name), dtype)
elif bigframes.dtypes.is_array_like(dtype):
return _validate_content(
array.flatten(), bigframes.dtypes.get_array_inner_type(dtype)
)


# Sequential iterator, but could split into batches and leverage parallelism for speed
def _iter_table(
@@ -280,6 +263,34 @@ def _adapt_pandas_series(
def _adapt_arrow_array(
array: Union[pa.ChunkedArray, pa.Array]
) -> tuple[Union[pa.ChunkedArray, pa.Array], bigframes.dtypes.Dtype]:
"""Normalize the array to managed storage types. Preverse shapes, only transforms values."""
if pa.types.is_struct(array.type):
assert isinstance(array, pa.StructArray)
assert isinstance(array.type, pa.StructType)
arrays = []
dtypes = []
pa_fields = []
for i in range(array.type.num_fields):
field_array, field_type = _adapt_arrow_array(array.field(i))
arrays.append(field_array)
dtypes.append(field_type)
pa_fields.append(pa.field(array.type.field(i).name, field_array.type))
struct_array = pa.StructArray.from_arrays(
arrays=arrays, fields=pa_fields, mask=array.is_null()
)
dtype = bigframes.dtypes.struct_type(
[(field.name, dtype) for field, dtype in zip(pa_fields, dtypes)]
)
return struct_array, dtype
if pa.types.is_list(array.type):
assert isinstance(array, pa.ListArray)
values, values_type = _adapt_arrow_array(array.values)
new_value = pa.ListArray.from_arrays(
array.offsets, values, mask=array.is_null()
)
return new_value.fill_null([]), bigframes.dtypes.list_type(values_type)
if array.type == bigframes.dtypes.JSON_ARROW_TYPE:
Contributor: With _canonicalize_json, you probably can remove _validate_content method in this module

Contributor Author: removed

return _canonicalize_json(array), bigframes.dtypes.JSON_DTYPE
target_type = _logical_type_replacements(array.type)
if target_type != array.type:
# TODO: Maybe warn if lossy conversion?
@@ -292,6 +303,22 @@ def _adapt_arrow_array(
return array, bf_type


def _canonicalize_json(array: pa.Array) -> pa.Array:
def _canonicalize_scalar(json_string):
if json_string is None:
return None
# This is the canonical form that bq uses when emitting json
# The sorted keys and unambiguous whitespace ensures a 1:1 mapping
# between syntax and semantics.
return json.dumps(
Contributor: Please add a comment explaining why we need to canonicalize JSON here. Refer to:
# sort_keys=True sorts dictionary keys before serialization, making
# JSON comparisons deterministic.
# separators=(',', ':') eliminates whitespace to get the most compact
# JSON representation.

Contributor Author: added comment

(A standalone illustration of the canonical form follows this file's diff.)

json.loads(json_string), sort_keys=True, separators=(",", ":")
)

return pa.array(
[_canonicalize_scalar(value) for value in array.to_pylist()], type=pa.string()
)


def _get_managed_storage_type(dtype: bigframes.dtypes.Dtype) -> pa.DataType:
if dtype in _MANAGED_STORAGE_TYPES_OVERRIDES.keys():
return _MANAGED_STORAGE_TYPES_OVERRIDES[dtype]
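For illustration only (not part of the PR diff): the review threads above concern _canonicalize_json, so here is a minimal standalone sketch, using only the json standard library and made-up sample strings, of why canonicalization both normalizes equivalent JSON and makes a separate content-validation pass unnecessary.

import json

def canonicalize(json_string):
    # Mirrors _canonicalize_json's scalar helper: sorted keys plus compact
    # separators give exactly one text form per JSON value.
    if json_string is None:
        return None
    return json.dumps(json.loads(json_string), sort_keys=True, separators=(",", ":"))

# Two spellings of the same value collapse to the same canonical string.
assert canonicalize('{"b": 2, "a": 1}') == canonicalize('{"a":1,  "b":2}') == '{"a":1,"b":2}'

# Malformed input fails inside json.loads, which is why the tests below now
# expect json.JSONDecodeError rather than the ValueError raised by the
# removed _validate_content helper.
try:
    canonicalize("{not json}")
except json.JSONDecodeError as exc:
    print("rejected:", exc)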
6 changes: 5 additions & 1 deletion bigframes/core/rewrite/__init__.py
@@ -17,7 +17,10 @@
from bigframes.core.rewrite.legacy_align import legacy_join_as_projection
from bigframes.core.rewrite.order import bake_order, defer_order
from bigframes.core.rewrite.pruning import column_pruning
from bigframes.core.rewrite.scan_reduction import try_reduce_to_table_scan
from bigframes.core.rewrite.scan_reduction import (
try_reduce_to_local_scan,
try_reduce_to_table_scan,
)
from bigframes.core.rewrite.slices import pull_up_limits, rewrite_slice
from bigframes.core.rewrite.timedeltas import rewrite_timedelta_expressions
from bigframes.core.rewrite.windows import rewrite_range_rolling
@@ -34,4 +37,5 @@
"rewrite_range_rolling",
"try_reduce_to_table_scan",
"bake_order",
"try_reduce_to_local_scan",
]
19 changes: 17 additions & 2 deletions bigframes/core/rewrite/scan_reduction.py
@@ -28,18 +28,33 @@ def try_reduce_to_table_scan(root: nodes.BigFrameNode) -> Optional[nodes.ReadTab
return None


def try_reduce_to_local_scan(node: nodes.BigFrameNode) -> Optional[nodes.ReadLocalNode]:
if not all(
map(
lambda x: isinstance(x, (nodes.ReadLocalNode, nodes.SelectionNode)),
node.unique_nodes(),
)
):
return None
result = node.bottom_up(merge_scan)
if isinstance(result, nodes.ReadLocalNode):
return result
return None


@functools.singledispatch
def merge_scan(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
return node


@merge_scan.register
def _(node: nodes.SelectionNode) -> nodes.BigFrameNode:
if not isinstance(node.child, nodes.ReadTableNode):
if not isinstance(node.child, (nodes.ReadTableNode, nodes.ReadLocalNode)):
return node
if node.has_multi_referenced_ids:
return node

if isinstance(node, nodes.ReadLocalNode) and node.offsets_col is not None:
return node
selection = {
aliased_ref.ref.id: aliased_ref.id for aliased_ref in node.input_output_pairs
}
15 changes: 14 additions & 1 deletion bigframes/dtypes.py
@@ -19,7 +19,7 @@
import decimal
import textwrap
import typing
from typing import Any, Dict, List, Literal, Union
from typing import Any, Dict, List, Literal, Sequence, Union

import bigframes_vendored.constants as constants
import db_dtypes # type: ignore
@@ -370,6 +370,19 @@ def get_array_inner_type(type_: ExpressionType) -> Dtype:
return arrow_dtype_to_bigframes_dtype(list_type.value_type)


def list_type(values_type: Dtype) -> Dtype:
"""Create a list dtype with given value type."""
return pd.ArrowDtype(pa.list_(bigframes_dtype_to_arrow_dtype(values_type)))


def struct_type(fields: Sequence[tuple[str, Dtype]]) -> Dtype:
"""Create a struct dtype with give fields names and types."""
pa_fields = [
pa.field(str, bigframes_dtype_to_arrow_dtype(dtype)) for str, dtype in fields
]
return pd.ArrowDtype(pa.struct(pa_fields))


_ORDERABLE_SIMPLE_TYPES = set(
mapping.dtype for mapping in SIMPLE_TYPES if mapping.orderable
)
3 changes: 2 additions & 1 deletion bigframes/session/bq_caching_executor.py
@@ -35,7 +35,7 @@
import bigframes.dtypes
import bigframes.exceptions as bfe
import bigframes.features
from bigframes.session import executor, read_api_execution
from bigframes.session import executor, local_scan_executor, read_api_execution
import bigframes.session._io.bigquery as bq_io
import bigframes.session.metrics
import bigframes.session.planner
@@ -84,6 +84,7 @@ def __init__(
bqstoragereadclient=bqstoragereadclient,
project=self.bqclient.project,
),
local_scan_executor.LocalScanExecutor(),
)

def to_sql(
68 changes: 68 additions & 0 deletions bigframes/session/local_scan_executor.py
@@ -0,0 +1,68 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import Optional

from bigframes.core import bigframe_node, rewrite
from bigframes.session import executor, semi_executor


class LocalScanExecutor(semi_executor.SemiExecutor):
"""
Executes plans reducible to a arrow table scan.
"""

def execute(
self,
plan: bigframe_node.BigFrameNode,
ordered: bool,
peek: Optional[int] = None,
) -> Optional[executor.ExecuteResult]:
node = rewrite.try_reduce_to_local_scan(plan)
if not node:
return None

# TODO: Can support some slicing, sorting
def iterator_supplier():
offsets_col = (
node.offsets_col.sql if (node.offsets_col is not None) else None
)
arrow_table = node.local_data_source.to_pyarrow_table(
offsets_col=offsets_col
)
if peek:
arrow_table = arrow_table.slice(0, peek)

needed_cols = [item.source_id for item in node.scan_list.items]
if offsets_col is not None:
needed_cols.append(offsets_col)

arrow_table = arrow_table.select(needed_cols)
arrow_table = arrow_table.rename_columns(
{item.source_id: item.id.sql for item in node.scan_list.items}
)
yield from arrow_table.to_batches()

total_rows = node.row_count
if (peek is not None) and (total_rows is not None):
total_rows = min(peek, total_rows)

return executor.ExecuteResult(
arrow_batches=iterator_supplier,
schema=plan.schema,
query_job=None,
total_bytes=None,
total_rows=total_rows,
)
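As a usage note, and only a sketch under assumptions (the sample frame and the "row_offsets" column name are invented, and bigframes must be installed): LocalScanExecutor serves results through the ManagedArrowTable.to_pyarrow_table method added in local_data.py above, so the same data path can be exercised directly without a BigQuery job.

import pandas as pd

from bigframes.core.local_data import ManagedArrowTable

df = pd.DataFrame({"ints": [1, 2, 3], "strings": ["a", "b", "c"]})
mat = ManagedArrowTable.from_pandas(df)

# LocalScanExecutor asks for a pyarrow table (optionally with an appended
# offsets column), then selects and renames columns before yielding batches.
table = mat.to_pyarrow_table(offsets_col="row_offsets")
print(table.column_names)  # the appended "row_offsets" column is included
for batch in table.to_batches():
    print(batch.num_rows)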
2 changes: 1 addition & 1 deletion tests/system/small/bigquery/test_json.py
@@ -66,7 +66,7 @@ def test_json_set_w_more_pairs():
s, json_path_value_pairs=[("$.a", 1), ("$.b", 2), ("$.a", [3, 4, 5])]
)

expected_json = ['{"a": 3, "b": 2}', '{"a": 4, "b": 2}', '{"a": 5, "b": 2, "c": 1}']
expected_json = ['{"a": 3,"b":2}', '{"a":4,"b": 2}', '{"a": 5,"b":2,"c":1}']
expected = bpd.Series(expected_json, dtype=dtypes.JSON_DTYPE)

pd.testing.assert_series_equal(actual.to_pandas(), expected.to_pandas())
9 changes: 3 additions & 6 deletions tests/system/small/test_session.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import json
import random
import re
import tempfile
@@ -32,7 +33,6 @@
import pytest

import bigframes
import bigframes.core.indexes.base
import bigframes.dataframe
import bigframes.dtypes
import bigframes.ml.linear_model
@@ -990,10 +990,7 @@ def test_read_pandas_json_series_w_invalid_json(session, write_engine):
]
pd_s = pd.Series(json_data, dtype=bigframes.dtypes.JSON_DTYPE)

with pytest.raises(
ValueError,
match="Invalid JSON format found",
):
with pytest.raises(json.JSONDecodeError):
session.read_pandas(pd_s, write_engine=write_engine)


@@ -1101,7 +1098,7 @@ def test_read_pandas_w_nested_invalid_json(session, write_engine):
),
)

with pytest.raises(ValueError, match="Invalid JSON format found"):
with pytest.raises(json.JSONDecodeError):
session.read_pandas(pd_s, write_engine=write_engine)


2 changes: 1 addition & 1 deletion tests/unit/test_local_engine.py
@@ -41,7 +41,7 @@ def small_inline_frame() -> pd.DataFrame:
"bools": pd.Series([True, None, False], dtype="boolean"),
"strings": pd.Series(["b", "aa", "ccc"], dtype="string[pyarrow]"),
"intLists": pd.Series(
[[1, 2, 3], [4, 5, 6, 7], None],
[[1, 2, 3], [4, 5, 6, 7], []],
dtype=pd.ArrowDtype(pa.list_(pa.int64())),
),
},