Skip to content

Commit a944c58

Browse files
feat: Short circuit query for local scan
1 parent 087a32a commit a944c58

File tree

6 files changed

+102
-7
lines changed

6 files changed

+102
-7
lines changed

bigframes/core/local_data.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,15 +98,14 @@ def from_pyarrow(self, table: pa.Table) -> ManagedArrowTable:
9898
mat.validate()
9999
return mat
100100

101-
def to_parquet(
101+
def to_pyarrow_table(
102102
self,
103-
dst: Union[str, io.IOBase],
104103
*,
105104
offsets_col: Optional[str] = None,
106105
geo_format: Literal["wkb", "wkt"] = "wkt",
107106
duration_type: Literal["int", "duration"] = "duration",
108107
json_type: Literal["string"] = "string",
109-
):
108+
) -> pa.Table:
110109
pa_table = self.data
111110
if offsets_col is not None:
112111
pa_table = pa_table.append_column(
@@ -119,6 +118,23 @@ def to_parquet(
119118
f"duration as {duration_type} not yet implemented"
120119
)
121120
assert json_type == "string"
121+
return pa_table
122+
123+
def to_parquet(
    self,
    dst: Union[str, io.IOBase],
    *,
    offsets_col: Optional[str] = None,
    geo_format: Literal["wkb", "wkt"] = "wkt",
    duration_type: Literal["int", "duration"] = "duration",
    json_type: Literal["string"] = "string",
):
    """Write this table to *dst* as parquet.

    Thin wrapper: conversion options are forwarded unchanged to
    ``to_pyarrow_table`` and the resulting table is handed to
    ``pyarrow.parquet.write_table``.
    """
    table = self.to_pyarrow_table(
        offsets_col=offsets_col,
        geo_format=geo_format,
        duration_type=duration_type,
        json_type=json_type,
    )
    pyarrow.parquet.write_table(table, where=dst)
123139

124140
def itertuples(

bigframes/core/rewrite/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
from bigframes.core.rewrite.legacy_align import legacy_join_as_projection
1818
from bigframes.core.rewrite.order import pull_up_order
1919
from bigframes.core.rewrite.pruning import column_pruning
20-
from bigframes.core.rewrite.scan_reduction import try_reduce_to_table_scan
20+
from bigframes.core.rewrite.scan_reduction import (
21+
try_reduce_to_local_scan,
22+
try_reduce_to_table_scan,
23+
)
2124
from bigframes.core.rewrite.slices import pullup_limit_from_slice, rewrite_slice
2225
from bigframes.core.rewrite.timedeltas import rewrite_timedelta_expressions
2326
from bigframes.core.rewrite.windows import rewrite_range_rolling
@@ -33,4 +36,5 @@
3336
"column_pruning",
3437
"rewrite_range_rolling",
3538
"try_reduce_to_table_scan",
39+
"try_reduce_to_local_scan",
3640
]

bigframes/core/rewrite/scan_reduction.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,28 @@ def try_reduce_to_table_scan(root: nodes.BigFrameNode) -> Optional[nodes.ReadTab
2828
return None
2929

3030

31+
def try_reduce_to_local_scan(node: nodes.BigFrameNode) -> Optional[nodes.ReadLocalNode]:
    """Attempt to collapse *node* into a single ``ReadLocalNode``.

    Only trees composed entirely of ``ReadLocalNode`` and ``SelectionNode``
    are eligible; for any other tree, or if merging does not bottom out in a
    single local read, ``None`` is returned.
    """
    permitted = (nodes.ReadLocalNode, nodes.SelectionNode)
    if any(not isinstance(n, permitted) for n in node.unique_nodes()):
        return None
    collapsed = node.bottom_up(merge_scan)
    return collapsed if isinstance(collapsed, nodes.ReadLocalNode) else None
43+
44+
3145
@functools.singledispatch
def merge_scan(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
    # Default case of the single-dispatch merge: node types without a
    # registered handler are returned unchanged. Type-specific collapsing
    # (e.g. for SelectionNode) is attached via @merge_scan.register.
    return node
3448

3549

3650
@merge_scan.register
3751
def _(node: nodes.SelectionNode) -> nodes.BigFrameNode:
38-
if not isinstance(node.child, nodes.ReadTableNode):
52+
if not isinstance(node.child, (nodes.ReadTableNode, nodes.ReadLocalNode)):
3953
return node
4054
if node.has_multi_referenced_ids:
4155
return node

bigframes/session/bq_caching_executor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
import bigframes.dtypes
3636
import bigframes.exceptions as bfe
3737
import bigframes.features
38-
from bigframes.session import executor, read_api_execution
38+
from bigframes.session import executor, local_scan_execution, read_api_execution
3939
import bigframes.session._io.bigquery as bq_io
4040
import bigframes.session.metrics
4141
import bigframes.session.planner
@@ -84,6 +84,7 @@ def __init__(
8484
bqstoragereadclient=bqstoragereadclient,
8585
project=self.bqclient.project,
8686
),
87+
local_scan_execution.LocalScanExecutor(),
8788
)
8889

8990
def to_sql(
bigframes/session/local_scan_execution.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
from typing import Optional
17+
18+
from bigframes.core import bigframe_node, rewrite
19+
from bigframes.session import executor, semi_executor
20+
21+
22+
class LocalScanExecutor(semi_executor.SemiExecutor):
    """
    Executes plans reducible to an arrow table scan.
    """

    def execute(
        self,
        plan: bigframe_node.BigFrameNode,
        ordered: bool,
        peek: Optional[int] = None,
    ) -> Optional[executor.ExecuteResult]:
        """Try to execute *plan* as a pure in-memory arrow scan.

        Returns ``None`` when the plan cannot be reduced to a local scan,
        allowing a subsequent executor in the chain to handle it.
        """
        node = rewrite.try_reduce_to_local_scan(plan)
        if not node:
            return None

        # TODO: Can support some slicing, sorting
        def iterator_supplier():
            offsets_col = (
                node.offsets_col.sql if (node.offsets_col is not None) else None
            )
            arrow_table = node.local_data_source.to_pyarrow_table(
                offsets_col=offsets_col
            )
            # Use an explicit None check so peek=0 yields zero rows rather
            # than being treated as "no peek".
            if peek is not None:
                arrow_table = arrow_table.slice(0, peek)
            for batch in arrow_table.to_batches():
                # Project to just the scanned columns, then rename source
                # ids to the output ids the plan schema expects.
                batch = batch.select([item.source_id for item in node.scan_list.items])
                batch = batch.rename_columns(
                    {item.source_id: item.id.sql for item in node.scan_list.items}
                )
                yield batch

        # Clamp the reported row count to the actual data size:
        # Table.slice truncates silently, so a peek larger than the table
        # must not overstate total_rows.
        total_rows = node.local_data_source.metadata.row_count
        if peek is not None:
            total_rows = min(peek, total_rows)

        return executor.ExecuteResult(
            arrow_batches=iterator_supplier,
            schema=plan.schema,
            query_job=None,
            total_bytes=None,
            total_rows=total_rows,
        )

tests/system/small/bigquery/test_json.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def test_json_set_w_more_pairs():
6666
s, json_path_value_pairs=[("$.a", 1), ("$.b", 2), ("$.a", [3, 4, 5])]
6767
)
6868

69-
expected_json = ['{"a": 3, "b": 2}', '{"a": 4, "b": 2}', '{"a": 5, "b": 2, "c": 1}']
69+
expected_json = ['{"a":3,"b":2}', '{"a":4,"b":2}', '{"a":5,"b":2,"c":1}']
7070
expected = bpd.Series(expected_json, dtype=dtypes.JSON_DTYPE)
7171

7272
pd.testing.assert_series_equal(actual.to_pandas(), expected.to_pandas())

0 commit comments

Comments
 (0)