Commit ec35cd8

feat: Short circuit query for local scan

1 parent: 40d6960

File tree: 5 files changed (+99, −6 lines)

bigframes/core/local_data.py

Lines changed: 19 additions & 3 deletions

@@ -96,15 +96,14 @@ def from_pyarrow(self, table: pa.Table) -> ManagedArrowTable:
             schemata.ArraySchema(tuple(fields)),
         )
 
-    def to_parquet(
+    def to_arrow(
         self,
-        dst: Union[str, io.IOBase],
         *,
         offsets_col: Optional[str] = None,
         geo_format: Literal["wkb", "wkt"] = "wkt",
         duration_type: Literal["int", "duration"] = "duration",
         json_type: Literal["string"] = "string",
-    ):
+    ) -> pa.Table:
         pa_table = self.data
         if offsets_col is not None:
             pa_table = pa_table.append_column(
@@ -117,6 +116,23 @@ def to_parquet(
                 f"duration as {duration_type} not yet implemented"
             )
         assert json_type == "string"
+        return pa_table
+
+    def to_parquet(
+        self,
+        dst: Union[str, io.IOBase],
+        *,
+        offsets_col: Optional[str] = None,
+        geo_format: Literal["wkb", "wkt"] = "wkt",
+        duration_type: Literal["int", "duration"] = "duration",
+        json_type: Literal["string"] = "string",
+    ):
+        pa_table = self.to_arrow(
+            offsets_col=offsets_col,
+            geo_format=geo_format,
+            duration_type=duration_type,
+            json_type=json_type,
+        )
         pyarrow.parquet.write_table(pa_table, where=dst)
 
     def itertuples(
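
This refactor splits materialization from serialization: `to_arrow` builds the processed `pa.Table` (optional offsets column, geo/duration/json conversions), and `to_parquet` becomes a thin wrapper that writes that table out. A minimal usage sketch, assuming `from_pyarrow` acts as a constructor as its signature suggests; the offsets column name below is hypothetical:

```python
import pyarrow as pa

from bigframes.core.local_data import ManagedArrowTable

# Assumption: from_pyarrow is a constructor-style factory, per its signature.
mat = ManagedArrowTable.from_pyarrow(pa.table({"a": [1, 2, 3]}))

# New in this commit: get the processed pyarrow.Table directly, optionally
# with an appended offsets column for ordering ("_offsets" is hypothetical).
table = mat.to_arrow(offsets_col="_offsets")

# Unchanged caller-facing behavior: to_parquet now delegates to to_arrow
# and writes the result via pyarrow.parquet.write_table.
mat.to_parquet("/tmp/example.parquet", offsets_col="_offsets")
```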

bigframes/core/rewrite/__init__.py

Lines changed: 5 additions & 1 deletion

@@ -17,7 +17,10 @@
 from bigframes.core.rewrite.legacy_align import legacy_join_as_projection
 from bigframes.core.rewrite.order import pull_up_order
 from bigframes.core.rewrite.pruning import column_pruning
-from bigframes.core.rewrite.scan_reduction import try_reduce_to_table_scan
+from bigframes.core.rewrite.scan_reduction import (
+    try_reduce_to_local_scan,
+    try_reduce_to_table_scan,
+)
 from bigframes.core.rewrite.slices import pullup_limit_from_slice, rewrite_slice
 from bigframes.core.rewrite.timedeltas import rewrite_timedelta_expressions
 from bigframes.core.rewrite.windows import rewrite_range_rolling
@@ -33,4 +36,5 @@
     "column_pruning",
     "rewrite_range_rolling",
     "try_reduce_to_table_scan",
+    "try_reduce_to_local_scan",
 ]

bigframes/core/rewrite/scan_reduction.py

Lines changed: 15 additions & 1 deletion

@@ -28,14 +28,28 @@ def try_reduce_to_table_scan(root: nodes.BigFrameNode) -> Optional[nodes.ReadTab
     return None
 
 
+def try_reduce_to_local_scan(node: nodes.BigFrameNode) -> Optional[nodes.ReadLocalNode]:
+    if not all(
+        map(
+            lambda x: isinstance(x, (nodes.ReadLocalNode, nodes.SelectionNode)),
+            node.unique_nodes(),
+        )
+    ):
+        return None
+    result = node.bottom_up(merge_scan)
+    if isinstance(result, nodes.ReadLocalNode):
+        return result
+    return None
+
+
 @functools.singledispatch
 def merge_scan(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
     return node
 
 
 @merge_scan.register
 def _(node: nodes.SelectionNode) -> nodes.BigFrameNode:
-    if not isinstance(node.child, nodes.ReadTableNode):
+    if not isinstance(node.child, (nodes.ReadTableNode, nodes.ReadLocalNode)):
         return node
     if node.has_multi_referenced_ids:
         return node
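
The new `try_reduce_to_local_scan` first checks that the plan contains only `ReadLocalNode` and `SelectionNode` nodes, then runs a bottom-up rewrite in which the generic `merge_scan` leaves nodes untouched and the `SelectionNode` overload folds a projection into the scan beneath it, so a stack of selections over a local read collapses to a single scan. A self-contained sketch of that singledispatch pattern, with toy `Scan`/`Select` classes standing in for the bigframes node types:

```python
import functools
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Scan:  # stands in for nodes.ReadLocalNode
    columns: Tuple[str, ...]


@dataclass(frozen=True)
class Select:  # stands in for nodes.SelectionNode
    child: object
    keep: Tuple[str, ...]


@functools.singledispatch
def merge_scan(node):
    return node  # default: leave the node unchanged


@merge_scan.register
def _(node: Select):
    if not isinstance(node.child, Scan):
        return node
    # Fold the projection into the scan itself.
    return Scan(tuple(c for c in node.child.columns if c in node.keep))


def bottom_up(node, fn):
    # Rewrite children first, then apply fn to the rebuilt node.
    if isinstance(node, Select):
        node = Select(bottom_up(node.child, fn), node.keep)
    return fn(node)


plan = Select(Select(Scan(("a", "b", "c")), keep=("a", "b")), keep=("a",))
print(bottom_up(plan, merge_scan))  # Scan(columns=('a',))
```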

bigframes/session/bq_caching_executor.py

Lines changed: 2 additions & 1 deletion

@@ -35,7 +35,7 @@
 import bigframes.dtypes
 import bigframes.exceptions as bfe
 import bigframes.features
-from bigframes.session import executor, read_api_execution
+from bigframes.session import executor, local_scan_execution, read_api_execution
 import bigframes.session._io.bigquery as bq_io
 import bigframes.session.metrics
 import bigframes.session.planner
@@ -84,6 +84,7 @@ def __init__(
                 bqstoragereadclient=bqstoragereadclient,
                 project=self.bqclient.project,
             ),
+            local_scan_execution.LocalScanExecutor(),
         )
 
     def to_sql(
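
This registration is the "short circuit" of the commit title: the caching executor consults its semi-executors in order and returns the first non-None result, so plans that reduce to a local scan never reach BigQuery. A sketch of that dispatch shape; the names below are illustrative, not the actual bigframes internals:

```python
from typing import Callable, Optional, Sequence


class SemiExecutor:
    """Serves a subset of plans; returns None for plans it cannot handle."""

    def execute(self, plan: object) -> Optional[object]:
        raise NotImplementedError


def execute_with_short_circuit(
    plan: object,
    semi_executors: Sequence[SemiExecutor],
    fallback: Callable[[object], object],
) -> object:
    # The first specialized executor that recognizes the plan wins.
    for semi in semi_executors:
        result = semi.execute(plan)
        if result is not None:
            return result
    # Nothing short-circuited: compile to SQL and run in BigQuery.
    return fallback(plan)
```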
bigframes/session/local_scan_execution.py (new file)

Lines changed: 58 additions & 0 deletions

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import Optional

from bigframes.core import bigframe_node, rewrite
from bigframes.session import executor, semi_executor


class LocalScanExecutor(semi_executor.SemiExecutor):
    """
    Executes plans reducible to an arrow table scan.
    """

    def execute(
        self,
        plan: bigframe_node.BigFrameNode,
        ordered: bool,
        peek: Optional[int] = None,
    ) -> Optional[executor.ExecuteResult]:
        node = rewrite.try_reduce_to_local_scan(plan)
        if not node:
            return None

        # TODO: Can support some slicing, sorting
        def iterator_supplier():
            offsets_col = (
                node.offsets_col.sql if (node.offsets_col is not None) else None
            )
            arrow_table = node.local_data_source.to_arrow(offsets_col=offsets_col)
            if peek:
                arrow_table = arrow_table.slice(0, peek)
            for batch in arrow_table.to_batches():
                # Keep only scanned columns and map source ids to output ids.
                batch = batch.select([item.source_id for item in node.scan_list.items])
                batch = batch.rename_columns(
                    {item.source_id: item.id.sql for item in node.scan_list.items}
                )
                yield batch

        return executor.ExecuteResult(
            arrow_batches=iterator_supplier,
            schema=plan.schema,
            query_job=None,
            total_bytes=None,
            total_rows=peek or node.local_data_source.metadata.row_count,
        )
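
End to end, a plan consisting only of local data and column selections can now be materialized without a query job, which is what `query_job=None` above encodes. A hedged illustration of a plan that should qualify, assuming standard `bigframes.pandas` usage:

```python
import pandas as pd

import bigframes.pandas as bpd

# Data originates locally, so the plan root is a ReadLocalNode.
df = bpd.DataFrame(pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]}))

# A pure column selection keeps the plan within
# {ReadLocalNode, SelectionNode}, so try_reduce_to_local_scan succeeds
# and LocalScanExecutor can serve the result from the in-memory Arrow table.
result = df[["a"]].to_pandas()
```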
