Commit a89e770
feat: Short circuit query for local scan
1 parent 087a32a commit a89e770

6 files changed: 118 additions & 8 deletions

bigframes/core/local_data.py
Lines changed: 34 additions & 3 deletions

@@ -98,15 +98,14 @@ def from_pyarrow(self, table: pa.Table) -> ManagedArrowTable:
         mat.validate()
         return mat

-    def to_parquet(
+    def to_pyarrow_table(
         self,
-        dst: Union[str, io.IOBase],
         *,
         offsets_col: Optional[str] = None,
         geo_format: Literal["wkb", "wkt"] = "wkt",
         duration_type: Literal["int", "duration"] = "duration",
         json_type: Literal["string"] = "string",
-    ):
+    ) -> pa.Table:
         pa_table = self.data
         if offsets_col is not None:
             pa_table = pa_table.append_column(
@@ -119,6 +118,23 @@ def to_parquet(
                 f"duration as {duration_type} not yet implemented"
             )
         assert json_type == "string"
+        return pa_table
+
+    def to_parquet(
+        self,
+        dst: Union[str, io.IOBase],
+        *,
+        offsets_col: Optional[str] = None,
+        geo_format: Literal["wkb", "wkt"] = "wkt",
+        duration_type: Literal["int", "duration"] = "duration",
+        json_type: Literal["string"] = "string",
+    ):
+        pa_table = self.to_pyarrow_table(
+            offsets_col=offsets_col,
+            geo_format=geo_format,
+            duration_type=duration_type,
+            json_type=json_type,
+        )
         pyarrow.parquet.write_table(pa_table, where=dst)

     def itertuples(
@@ -280,6 +296,8 @@ def _adapt_pandas_series(
 def _adapt_arrow_array(
     array: Union[pa.ChunkedArray, pa.Array]
 ) -> tuple[Union[pa.ChunkedArray, pa.Array], bigframes.dtypes.Dtype]:
+    if array.type == bigframes.dtypes.JSON_ARROW_TYPE:
+        return _canonicalize_json(array), bigframes.dtypes.JSON_DTYPE
     target_type = _logical_type_replacements(array.type)
     if target_type != array.type:
         # TODO: Maybe warn if lossy conversion?
@@ -292,6 +310,19 @@ def _adapt_arrow_array(
     return array, bf_type


+def _canonicalize_json(array: pa.Array) -> pa.Array:
+    def _canonicalize_scalar(json_string):
+        if json_string is None:
+            return None
+        return json.dumps(
+            json.loads(json_string), sort_keys=True, separators=(",", ":")
+        )
+
+    return pa.array(
+        [_canonicalize_scalar(value) for value in array.to_pylist()], type=pa.string()
+    )
+
+
 def _get_managed_storage_type(dtype: bigframes.dtypes.Dtype) -> pa.DataType:
     if dtype in _MANAGED_STORAGE_TYPES_OVERRIDES.keys():
         return _MANAGED_STORAGE_TYPES_OVERRIDES[dtype]
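
The new _canonicalize_json hook normalizes JSON strings at ingestion time so that key order and whitespace differences don't leak into comparisons. A minimal standalone sketch of the same rule, using only the stdlib and pyarrow (the canonicalize name is illustrative, not part of bigframes):

import json

import pyarrow as pa

def canonicalize(values):
    # Parse each JSON string and re-serialize with sorted keys and no
    # whitespace, so equal documents become equal strings; None passes through.
    out = [
        None
        if s is None
        else json.dumps(json.loads(s), sort_keys=True, separators=(",", ":"))
        for s in values
    ]
    return pa.array(out, type=pa.string())

print(canonicalize(['{"b": 2, "a": 1}', None]).to_pylist())
# ['{"a":1,"b":2}', None]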

bigframes/core/rewrite/__init__.py
Lines changed: 5 additions & 1 deletion

@@ -17,7 +17,10 @@
 from bigframes.core.rewrite.legacy_align import legacy_join_as_projection
 from bigframes.core.rewrite.order import pull_up_order
 from bigframes.core.rewrite.pruning import column_pruning
-from bigframes.core.rewrite.scan_reduction import try_reduce_to_table_scan
+from bigframes.core.rewrite.scan_reduction import (
+    try_reduce_to_local_scan,
+    try_reduce_to_table_scan,
+)
 from bigframes.core.rewrite.slices import pullup_limit_from_slice, rewrite_slice
 from bigframes.core.rewrite.timedeltas import rewrite_timedelta_expressions
 from bigframes.core.rewrite.windows import rewrite_range_rolling
@@ -33,4 +36,5 @@
     "column_pruning",
     "rewrite_range_rolling",
     "try_reduce_to_table_scan",
+    "try_reduce_to_local_scan",
 ]

bigframes/core/rewrite/scan_reduction.py
Lines changed: 15 additions & 1 deletion

@@ -28,14 +28,28 @@ def try_reduce_to_table_scan(root: nodes.BigFrameNode) -> Optional[nodes.ReadTab
     return None


+def try_reduce_to_local_scan(node: nodes.BigFrameNode) -> Optional[nodes.ReadLocalNode]:
+    if not all(
+        map(
+            lambda x: isinstance(x, (nodes.ReadLocalNode, nodes.SelectionNode)),
+            node.unique_nodes(),
+        )
+    ):
+        return None
+    result = node.bottom_up(merge_scan)
+    if isinstance(result, nodes.ReadLocalNode):
+        return result
+    return None
+
+
 @functools.singledispatch
 def merge_scan(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
     return node


 @merge_scan.register
 def _(node: nodes.SelectionNode) -> nodes.BigFrameNode:
-    if not isinstance(node.child, nodes.ReadTableNode):
+    if not isinstance(node.child, (nodes.ReadTableNode, nodes.ReadLocalNode)):
         return node
     if node.has_multi_referenced_ids:
         return node
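
For context, merge_scan is a functools.singledispatch rewrite applied bottom-up: each SelectionNode sitting directly on a scan gets folded into the scan, and try_reduce_to_local_scan succeeds only if the whole plan collapses to a single ReadLocalNode. A toy sketch of the same dispatch-and-fold pattern (these dataclasses are illustrative stand-ins, not bigframes node types):

import functools
from dataclasses import dataclass

@dataclass(frozen=True)
class Scan:
    columns: tuple

@dataclass(frozen=True)
class Selection:
    child: object
    keep: tuple

@functools.singledispatch
def merge(node):
    return node  # default: leave unknown node types unchanged

@merge.register
def _(node: Selection):
    # Fold a projection directly into the scan beneath it.
    if isinstance(node.child, Scan):
        return Scan(tuple(c for c in node.child.columns if c in node.keep))
    return node

plan = Selection(Scan(("a", "b", "c")), keep=("a", "c"))
print(merge(plan))  # Scan(columns=('a', 'c'))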

bigframes/session/bq_caching_executor.py
Lines changed: 2 additions & 1 deletion

@@ -35,7 +35,7 @@
 import bigframes.dtypes
 import bigframes.exceptions as bfe
 import bigframes.features
-from bigframes.session import executor, read_api_execution
+from bigframes.session import executor, local_scan_execution, read_api_execution
 import bigframes.session._io.bigquery as bq_io
 import bigframes.session.metrics
 import bigframes.session.planner
@@ -84,6 +84,7 @@ def __init__(
                 bqstoragereadclient=bqstoragereadclient,
                 project=self.bqclient.project,
             ),
+            local_scan_execution.LocalScanExecutor(),
         )

     def to_sql(
bigframes/session/local_scan_execution.py
Lines changed: 60 additions & 0 deletions

@@ -0,0 +1,60 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from typing import Optional
+
+from bigframes.core import bigframe_node, rewrite
+from bigframes.session import executor, semi_executor
+
+
+class LocalScanExecutor(semi_executor.SemiExecutor):
+    """
+    Executes plans reducible to an arrow table scan.
+    """
+
+    def execute(
+        self,
+        plan: bigframe_node.BigFrameNode,
+        ordered: bool,
+        peek: Optional[int] = None,
+    ) -> Optional[executor.ExecuteResult]:
+        node = rewrite.try_reduce_to_local_scan(plan)
+        if not node:
+            return None
+
+        # TODO: Can support some slicing, sorting
+        def iterator_supplier():
+            offsets_col = (
+                node.offsets_col.sql if (node.offsets_col is not None) else None
+            )
+            arrow_table = node.local_data_source.to_pyarrow_table(
+                offsets_col=offsets_col
+            )
+            if peek:
+                arrow_table = arrow_table.slice(0, peek)
+            for batch in arrow_table.to_batches():
+                batch = batch.select([item.source_id for item in node.scan_list.items])
+                batch = batch.rename_columns(
+                    {item.source_id: item.id.sql for item in node.scan_list.items}
+                )
+                yield batch
+
+        return executor.ExecuteResult(
+            arrow_batches=iterator_supplier,
+            schema=plan.schema,
+            query_job=None,
+            total_bytes=None,
+            total_rows=peek or node.local_data_source.metadata.row_count,
+        )
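
Taken together, this is what lets a purely local plan short-circuit the BigQuery round trip: the session asks each semi-executor in turn, and the first one that returns a result wins. A hypothetical sketch of that dispatch loop (function names here are illustrative, not the session's real API):

from typing import Callable, Iterable, Optional

def execute_with_fallback(
    plan: str,
    executors: Iterable[Callable[[str], Optional[str]]],
) -> str:
    # Walk the chain; the first executor that does not return None
    # handles the plan, short-circuiting everything after it.
    for ex in executors:
        result = ex(plan)
        if result is not None:
            return result
    raise RuntimeError("no executor could handle the plan")

def local_scan(plan: str) -> Optional[str]:
    # Stand-in for LocalScanExecutor: only handles purely local plans.
    return "scanned local arrow table" if plan == "local" else None

def bq_execute(plan: str) -> str:
    # Stand-in for the BigQuery-backed executor: always succeeds.
    return f"ran BigQuery job for {plan!r}"

print(execute_with_fallback("local", [local_scan, bq_execute]))
print(execute_with_fallback("remote", [local_scan, bq_execute]))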

tests/system/small/bigquery/test_json.py
Lines changed: 2 additions & 2 deletions

@@ -66,7 +66,7 @@ def test_json_set_w_more_pairs():
         s, json_path_value_pairs=[("$.a", 1), ("$.b", 2), ("$.a", [3, 4, 5])]
     )

-    expected_json = ['{"a": 3, "b": 2}', '{"a": 4, "b": 2}', '{"a": 5, "b": 2, "c": 1}']
+    expected_json = ['{"a":3,"b":2}', '{"a":4,"b":2}', '{"a":5,"b":2,"c":1}']
     expected = bpd.Series(expected_json, dtype=dtypes.JSON_DTYPE)

     pd.testing.assert_series_equal(actual.to_pandas(), expected.to_pandas())
@@ -155,7 +155,7 @@ def test_json_extract_array_from_json_strings():
     )
     actual = bbq.json_extract_array(s, "$.a")
     expected = bpd.Series(
-        [['"ab"', '"2"', '"3 xy"'], [], ['"4"', '"5"'], None],
+        [['"ab"', '"2"', '"3 xy"'], [], ['"4"', '"5"'], []],
         dtype=pd.ArrowDtype(pa.list_(pa.string())),
     )