Commit e84f232

feat: Short circuit query for local scan (#1618)
1 parent ae83e61 commit e84f232

File tree

9 files changed: +176, -51 lines

bigframes/core/local_data.py

Lines changed: 65 additions & 38 deletions

```diff
@@ -31,7 +31,6 @@
 import pyarrow.parquet  # type: ignore
 
 import bigframes.core.schema as schemata
-import bigframes.core.utils as utils
 import bigframes.dtypes
 
 
@@ -79,7 +78,7 @@ def from_pandas(cls, dataframe: pd.DataFrame) -> ManagedArrowTable:
         mat = ManagedArrowTable(
             pa.table(columns, names=column_names), schemata.ArraySchema(tuple(fields))
         )
-        mat.validate(include_content=True)
+        mat.validate()
         return mat
 
     @classmethod
@@ -98,15 +97,14 @@ def from_pyarrow(self, table: pa.Table) -> ManagedArrowTable:
         mat.validate()
         return mat
 
-    def to_parquet(
+    def to_pyarrow_table(
         self,
-        dst: Union[str, io.IOBase],
         *,
         offsets_col: Optional[str] = None,
         geo_format: Literal["wkb", "wkt"] = "wkt",
         duration_type: Literal["int", "duration"] = "duration",
         json_type: Literal["string"] = "string",
-    ):
+    ) -> pa.Table:
         pa_table = self.data
         if offsets_col is not None:
             pa_table = pa_table.append_column(
@@ -119,6 +117,23 @@ def to_parquet(
                 f"duration as {duration_type} not yet implemented"
             )
         assert json_type == "string"
+        return pa_table
+
+    def to_parquet(
+        self,
+        dst: Union[str, io.IOBase],
+        *,
+        offsets_col: Optional[str] = None,
+        geo_format: Literal["wkb", "wkt"] = "wkt",
+        duration_type: Literal["int", "duration"] = "duration",
+        json_type: Literal["string"] = "string",
+    ):
+        pa_table = self.to_pyarrow_table(
+            offsets_col=offsets_col,
+            geo_format=geo_format,
+            duration_type=duration_type,
+            json_type=json_type,
+        )
         pyarrow.parquet.write_table(pa_table, where=dst)
 
     def itertuples(
@@ -142,7 +157,7 @@ def itertuples(
         ):
             yield tuple(row_dict.values())
 
-    def validate(self, include_content: bool = False):
+    def validate(self):
         for bf_field, arrow_field in zip(self.schema.items, self.data.schema):
             expected_arrow_type = _get_managed_storage_type(bf_field.dtype)
             arrow_type = arrow_field.type
@@ -151,38 +166,6 @@ def validate(self, include_content: bool = False):
                 f"Field {bf_field} has arrow array type: {arrow_type}, expected type: {expected_arrow_type}"
             )
 
-        if include_content:
-            for batch in self.data.to_batches():
-                for field in self.schema.items:
-                    _validate_content(batch.column(field.column), field.dtype)
-
-
-def _validate_content(array: pa.Array, dtype: bigframes.dtypes.Dtype):
-    """
-    Recursively validates the content of a PyArrow Array based on the
-    expected BigFrames dtype, focusing on complex types like JSON, structs,
-    and arrays where the Arrow type alone isn't sufficient.
-    """
-    # TODO: validate GEO data context.
-    if dtype == bigframes.dtypes.JSON_DTYPE:
-        values = array.to_pandas()
-        for data in values:
-            # Skip scalar null values to avoid `TypeError` from json.load.
-            if not utils.is_list_like(data) and pd.isna(data):
-                continue
-            try:
-                # Attempts JSON parsing.
-                json.loads(data)
-            except json.JSONDecodeError as e:
-                raise ValueError(f"Invalid JSON format found: {data!r}") from e
-    elif bigframes.dtypes.is_struct_like(dtype):
-        for field_name, dtype in bigframes.dtypes.get_struct_fields(dtype).items():
-            _validate_content(array.field(field_name), dtype)
-    elif bigframes.dtypes.is_array_like(dtype):
-        return _validate_content(
-            array.flatten(), bigframes.dtypes.get_array_inner_type(dtype)
-        )
-
 
 # Sequential iterator, but could split into batches and leverage parallelism for speed
 def _iter_table(
@@ -280,6 +263,34 @@ def _adapt_pandas_series(
 def _adapt_arrow_array(
     array: Union[pa.ChunkedArray, pa.Array]
 ) -> tuple[Union[pa.ChunkedArray, pa.Array], bigframes.dtypes.Dtype]:
+    """Normalize the array to managed storage types. Preserves shape; only transforms values."""
+    if pa.types.is_struct(array.type):
+        assert isinstance(array, pa.StructArray)
+        assert isinstance(array.type, pa.StructType)
+        arrays = []
+        dtypes = []
+        pa_fields = []
+        for i in range(array.type.num_fields):
+            field_array, field_type = _adapt_arrow_array(array.field(i))
+            arrays.append(field_array)
+            dtypes.append(field_type)
+            pa_fields.append(pa.field(array.type.field(i).name, field_array.type))
+        struct_array = pa.StructArray.from_arrays(
+            arrays=arrays, fields=pa_fields, mask=array.is_null()
+        )
+        dtype = bigframes.dtypes.struct_type(
+            [(field.name, dtype) for field, dtype in zip(pa_fields, dtypes)]
+        )
+        return struct_array, dtype
+    if pa.types.is_list(array.type):
+        assert isinstance(array, pa.ListArray)
+        values, values_type = _adapt_arrow_array(array.values)
+        new_value = pa.ListArray.from_arrays(
+            array.offsets, values, mask=array.is_null()
+        )
+        return new_value.fill_null([]), bigframes.dtypes.list_type(values_type)
+    if array.type == bigframes.dtypes.JSON_ARROW_TYPE:
+        return _canonicalize_json(array), bigframes.dtypes.JSON_DTYPE
     target_type = _logical_type_replacements(array.type)
     if target_type != array.type:
         # TODO: Maybe warn if lossy conversion?
@@ -292,6 +303,22 @@ def _adapt_arrow_array(
     return array, bf_type
 
 
+def _canonicalize_json(array: pa.Array) -> pa.Array:
+    def _canonicalize_scalar(json_string):
+        if json_string is None:
+            return None
+        # This is the canonical form that bq uses when emitting json.
+        # The sorted keys and unambiguous whitespace ensure a 1:1 mapping
+        # between syntax and semantics.
+        return json.dumps(
+            json.loads(json_string), sort_keys=True, separators=(",", ":")
+        )
+
+    return pa.array(
+        [_canonicalize_scalar(value) for value in array.to_pylist()], type=pa.string()
+    )
+
+
 def _get_managed_storage_type(dtype: bigframes.dtypes.Dtype) -> pa.DataType:
     if dtype in _MANAGED_STORAGE_TYPES_OVERRIDES.keys():
         return _MANAGED_STORAGE_TYPES_OVERRIDES[dtype]
```
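
The `_canonicalize_json` helper above replaces the old content validation: instead of merely checking JSON strings, BigFrames now rewrites them into the single canonical form BigQuery emits. This is also why the JSON tests further down loosen their whitespace expectations, and why invalid input now raises `json.JSONDecodeError` straight from `json.loads` rather than a wrapped `ValueError`. A minimal standalone sketch of the same round trip, stdlib only:

```python
import json
from typing import Optional

def canonicalize(json_string: Optional[str]) -> Optional[str]:
    """Re-serialize a JSON string into its one canonical spelling (sketch of _canonicalize_scalar)."""
    if json_string is None:
        return None
    # Sorted keys plus minimal separators yield exactly one text per JSON value.
    return json.dumps(json.loads(json_string), sort_keys=True, separators=(",", ":"))

assert canonicalize('{"b": 2, "a": 1}') == '{"a":1,"b":2}'
assert canonicalize('{"a": 3, "b": 2}') == canonicalize('{"a": 3,"b":2}')
# canonicalize("not json")  # raises json.JSONDecodeError, matching the updated tests
```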

bigframes/core/rewrite/__init__.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -17,7 +17,10 @@
 from bigframes.core.rewrite.legacy_align import legacy_join_as_projection
 from bigframes.core.rewrite.order import bake_order, defer_order
 from bigframes.core.rewrite.pruning import column_pruning
-from bigframes.core.rewrite.scan_reduction import try_reduce_to_table_scan
+from bigframes.core.rewrite.scan_reduction import (
+    try_reduce_to_local_scan,
+    try_reduce_to_table_scan,
+)
 from bigframes.core.rewrite.slices import pull_up_limits, rewrite_slice
 from bigframes.core.rewrite.timedeltas import rewrite_timedelta_expressions
 from bigframes.core.rewrite.windows import rewrite_range_rolling
@@ -34,4 +37,5 @@
     "rewrite_range_rolling",
     "try_reduce_to_table_scan",
     "bake_order",
+    "try_reduce_to_local_scan",
 ]
```

bigframes/core/rewrite/scan_reduction.py

Lines changed: 17 additions & 2 deletions

```diff
@@ -28,18 +28,33 @@ def try_reduce_to_table_scan(root: nodes.BigFrameNode) -> Optional[nodes.ReadTab
     return None
 
 
+def try_reduce_to_local_scan(node: nodes.BigFrameNode) -> Optional[nodes.ReadLocalNode]:
+    if not all(
+        map(
+            lambda x: isinstance(x, (nodes.ReadLocalNode, nodes.SelectionNode)),
+            node.unique_nodes(),
+        )
+    ):
+        return None
+    result = node.bottom_up(merge_scan)
+    if isinstance(result, nodes.ReadLocalNode):
+        return result
+    return None
+
+
 @functools.singledispatch
 def merge_scan(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
     return node
 
 
 @merge_scan.register
 def _(node: nodes.SelectionNode) -> nodes.BigFrameNode:
-    if not isinstance(node.child, nodes.ReadTableNode):
+    if not isinstance(node.child, (nodes.ReadTableNode, nodes.ReadLocalNode)):
         return node
     if node.has_multi_referenced_ids:
         return node
-
+    if isinstance(node, nodes.ReadLocalNode) and node.offsets_col is not None:
+        return node
     selection = {
         aliased_ref.ref.id: aliased_ref.id for aliased_ref in node.input_output_pairs
     }
```
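
`try_reduce_to_local_scan` first checks that the plan contains nothing but local reads and column selections, then runs the `singledispatch`-based `merge_scan` bottom-up so selections collapse into the scan beneath them. A toy sketch of that dispatch pattern, with hypothetical `Scan`/`Selection` stand-ins for the real `BigFrameNode` classes (the real traversal is `node.bottom_up`, not a single top-level call):

```python
import functools
from dataclasses import dataclass

@dataclass(frozen=True)
class Scan:  # hypothetical stand-in for nodes.ReadLocalNode
    columns: tuple

@dataclass(frozen=True)
class Selection:  # hypothetical stand-in for nodes.SelectionNode
    child: object
    keep: tuple

@functools.singledispatch
def merge(node):
    return node  # default: leave nodes we don't understand untouched

@merge.register
def _(node: Selection):
    # Fold a column selection directly into the scan beneath it.
    if isinstance(node.child, Scan):
        return Scan(columns=node.keep)
    return node

plan = Selection(child=Scan(columns=("a", "b", "c")), keep=("a",))
print(merge(plan))  # Scan(columns=('a',)) -- the plan reduced to a bare scan
```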

bigframes/dtypes.py

Lines changed: 14 additions & 1 deletion

```diff
@@ -19,7 +19,7 @@
 import decimal
 import textwrap
 import typing
-from typing import Any, Dict, List, Literal, Union
+from typing import Any, Dict, List, Literal, Sequence, Union
 
 import bigframes_vendored.constants as constants
 import db_dtypes  # type: ignore
@@ -370,6 +370,19 @@ def get_array_inner_type(type_: ExpressionType) -> Dtype:
     return arrow_dtype_to_bigframes_dtype(list_type.value_type)
 
 
+def list_type(values_type: Dtype) -> Dtype:
+    """Create a list dtype with given value type."""
+    return pd.ArrowDtype(pa.list_(bigframes_dtype_to_arrow_dtype(values_type)))
+
+
+def struct_type(fields: Sequence[tuple[str, Dtype]]) -> Dtype:
+    """Create a struct dtype with given field names and types."""
+    pa_fields = [
+        pa.field(str, bigframes_dtype_to_arrow_dtype(dtype)) for str, dtype in fields
+    ]
+    return pd.ArrowDtype(pa.struct(pa_fields))
+
+
 _ORDERABLE_SIMPLE_TYPES = set(
     mapping.dtype for mapping in SIMPLE_TYPES if mapping.orderable
 )
```
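
Both helpers are thin wrappers over the pyarrow type constructors; spelled out with raw pyarrow and pandas (the values here are illustrative only), the dtypes they build look like this:

```python
import pandas as pd
import pyarrow as pa

# Roughly what list_type / struct_type produce for simple element types.
int_list_dtype = pd.ArrowDtype(pa.list_(pa.int64()))
person_dtype = pd.ArrowDtype(
    pa.struct([pa.field("name", pa.string()), pa.field("age", pa.int64())])
)

lists = pd.Series([[1, 2], [], [3]], dtype=int_list_dtype)
people = pd.Series([{"name": "Ada", "age": 36}], dtype=person_dtype)
```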

bigframes/session/bq_caching_executor.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -35,7 +35,7 @@
 import bigframes.dtypes
 import bigframes.exceptions as bfe
 import bigframes.features
-from bigframes.session import executor, read_api_execution
+from bigframes.session import executor, local_scan_executor, read_api_execution
 import bigframes.session._io.bigquery as bq_io
 import bigframes.session.metrics
 import bigframes.session.planner
@@ -84,6 +84,7 @@ def __init__(
                 bqstoragereadclient=bqstoragereadclient,
                 project=self.bqclient.project,
             ),
+            local_scan_executor.LocalScanExecutor(),
         )
 
     def to_sql(
```
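
The caching executor keeps an ordered collection of semi-executors, and `LocalScanExecutor` now sits in it: each one either claims the plan or returns `None` to decline. A hedged sketch of that first-hit-wins dispatch (names are illustrative, not the actual bigframes internals):

```python
from typing import Optional, Sequence

def try_semi_executors(plan, semi_executors: Sequence) -> Optional[object]:
    """Return the first semi-executor's result, or None to fall back to BigQuery SQL."""
    for semi_executor in semi_executors:
        result = semi_executor.execute(plan, ordered=True)
        if result is not None:
            return result  # short circuit: no BigQuery job is issued
    return None  # nobody claimed the plan; compile SQL and run it remotely
```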
bigframes/session/local_scan_executor.py

Lines changed: 68 additions & 0 deletions

```diff
@@ -0,0 +1,68 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from typing import Optional
+
+from bigframes.core import bigframe_node, rewrite
+from bigframes.session import executor, semi_executor
+
+
+class LocalScanExecutor(semi_executor.SemiExecutor):
+    """
+    Executes plans reducible to an arrow table scan.
+    """
+
+    def execute(
+        self,
+        plan: bigframe_node.BigFrameNode,
+        ordered: bool,
+        peek: Optional[int] = None,
+    ) -> Optional[executor.ExecuteResult]:
+        node = rewrite.try_reduce_to_local_scan(plan)
+        if not node:
+            return None
+
+        # TODO: Can support some slicing, sorting
+        def iterator_supplier():
+            offsets_col = (
+                node.offsets_col.sql if (node.offsets_col is not None) else None
+            )
+            arrow_table = node.local_data_source.to_pyarrow_table(
+                offsets_col=offsets_col
+            )
+            if peek:
+                arrow_table = arrow_table.slice(0, peek)
+
+            needed_cols = [item.source_id for item in node.scan_list.items]
+            if offsets_col is not None:
+                needed_cols.append(offsets_col)
+
+            arrow_table = arrow_table.select(needed_cols)
+            arrow_table = arrow_table.rename_columns(
+                {item.source_id: item.id.sql for item in node.scan_list.items}
+            )
+            yield from arrow_table.to_batches()
+
+        total_rows = node.row_count
+        if (peek is not None) and (total_rows is not None):
+            total_rows = min(peek, total_rows)
+
+        return executor.ExecuteResult(
+            arrow_batches=iterator_supplier,
+            schema=plan.schema,
+            query_job=None,
+            total_bytes=None,
+            total_rows=total_rows,
+        )
```
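
Everything the executor does to the data is plain pyarrow table surgery: slice for `peek`, project the scan-list columns, rename source ids to output ids, and stream record batches. The same steps on a throwaway table (column names made up for illustration):

```python
import pyarrow as pa

table = pa.table({"src_a": [1, 2, 3, 4], "src_b": ["w", "x", "y", "z"]})

peeked = table.slice(0, 2)                     # honor peek=2 without copying rows
projected = peeked.select(["src_a"])           # keep only the scan-list columns
renamed = projected.rename_columns(["out_a"])  # map source id -> output id
batches = list(renamed.to_batches())           # what iterator_supplier yields from
```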

tests/system/small/bigquery/test_json.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -66,7 +66,7 @@ def test_json_set_w_more_pairs():
         s, json_path_value_pairs=[("$.a", 1), ("$.b", 2), ("$.a", [3, 4, 5])]
     )
 
-    expected_json = ['{"a": 3, "b": 2}', '{"a": 4, "b": 2}', '{"a": 5, "b": 2, "c": 1}']
+    expected_json = ['{"a": 3,"b":2}', '{"a":4,"b": 2}', '{"a": 5,"b":2,"c":1}']
     expected = bpd.Series(expected_json, dtype=dtypes.JSON_DTYPE)
 
     pd.testing.assert_series_equal(actual.to_pandas(), expected.to_pandas())
```

tests/system/small/test_session.py

Lines changed: 3 additions & 6 deletions

```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import io
+import json
 import random
 import re
 import tempfile
@@ -32,7 +33,6 @@
 import pytest
 
 import bigframes
-import bigframes.core.indexes.base
 import bigframes.dataframe
 import bigframes.dtypes
 import bigframes.ml.linear_model
@@ -990,10 +990,7 @@ def test_read_pandas_json_series_w_invalid_json(session, write_engine):
     ]
     pd_s = pd.Series(json_data, dtype=bigframes.dtypes.JSON_DTYPE)
 
-    with pytest.raises(
-        ValueError,
-        match="Invalid JSON format found",
-    ):
+    with pytest.raises(json.JSONDecodeError):
         session.read_pandas(pd_s, write_engine=write_engine)
 
 
@@ -1101,7 +1098,7 @@ def test_read_pandas_w_nested_invalid_json(session, write_engine):
         ),
     )
 
-    with pytest.raises(ValueError, match="Invalid JSON format found"):
+    with pytest.raises(json.JSONDecodeError):
         session.read_pandas(pd_s, write_engine=write_engine)
```
tests/unit/test_local_engine.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -41,7 +41,7 @@ def small_inline_frame() -> pd.DataFrame:
         "bools": pd.Series([True, None, False], dtype="boolean"),
         "strings": pd.Series(["b", "aa", "ccc"], dtype="string[pyarrow]"),
         "intLists": pd.Series(
-            [[1, 2, 3], [4, 5, 6, 7], None],
+            [[1, 2, 3], [4, 5, 6, 7], []],
             dtype=pd.ArrowDtype(pa.list_(pa.int64())),
         ),
     },
```
