feat: Short circuit query for local scan #1618
The first changed file in the diff is the module that defines ManagedArrowTable:

@@ -31,7 +31,6 @@
 import pyarrow.parquet  # type: ignore

 import bigframes.core.schema as schemata
-import bigframes.core.utils as utils
 import bigframes.dtypes
@@ -79,7 +78,7 @@ def from_pandas(cls, dataframe: pd.DataFrame) -> ManagedArrowTable:
         mat = ManagedArrowTable(
             pa.table(columns, names=column_names), schemata.ArraySchema(tuple(fields))
         )
-        mat.validate(include_content=True)
+        mat.validate()
         return mat

     @classmethod
@@ -98,15 +97,14 @@ def from_pyarrow(self, table: pa.Table) -> ManagedArrowTable:
         mat.validate()
         return mat

-    def to_parquet(
+    def to_pyarrow_table(
         self,
-        dst: Union[str, io.IOBase],
         *,
         offsets_col: Optional[str] = None,
         geo_format: Literal["wkb", "wkt"] = "wkt",
         duration_type: Literal["int", "duration"] = "duration",
         json_type: Literal["string"] = "string",
-    ):
+    ) -> pa.Table:
         pa_table = self.data
         if offsets_col is not None:
             pa_table = pa_table.append_column(
@@ -119,6 +117,23 @@ def to_parquet(
                 f"duration as {duration_type} not yet implemented"
             )
         assert json_type == "string"
+        return pa_table
+
+    def to_parquet(
+        self,
+        dst: Union[str, io.IOBase],
+        *,
+        offsets_col: Optional[str] = None,
+        geo_format: Literal["wkb", "wkt"] = "wkt",
+        duration_type: Literal["int", "duration"] = "duration",
+        json_type: Literal["string"] = "string",
+    ):
+        pa_table = self.to_pyarrow_table(
+            offsets_col=offsets_col,
+            geo_format=geo_format,
+            duration_type=duration_type,
+            json_type=json_type,
+        )
         pyarrow.parquet.write_table(pa_table, where=dst)

     def itertuples(
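This refactor separates building the in-memory Arrow table (the new to_pyarrow_table) from writing it out (to_parquet, which now delegates and hands the result to pyarrow.parquet.write_table). A rough standalone sketch of those two steps using plain pyarrow; the offsets column name and values here are illustrative, not part of the PR:

import pyarrow as pa
import pyarrow.parquet

# Step 1: build the Arrow table, optionally appending an offsets column
# (what to_pyarrow_table does for the ManagedArrowTable's data).
pa_table = pa.table({"a": [1, 2, 3]})
pa_table = pa_table.append_column("offsets", pa.array(range(pa_table.num_rows), type=pa.int64()))

# Step 2: write it to Parquet (what to_parquet does after delegating to step 1).
pyarrow.parquet.write_table(pa_table, where="rows.parquet")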
@@ -142,7 +157,7 @@ def itertuples(
         ):
             yield tuple(row_dict.values())

-    def validate(self, include_content: bool = False):
+    def validate(self):
         for bf_field, arrow_field in zip(self.schema.items, self.data.schema):
             expected_arrow_type = _get_managed_storage_type(bf_field.dtype)
             arrow_type = arrow_field.type
@@ -151,38 +166,6 @@ def validate(self, include_content: bool = False):
                     f"Field {bf_field} has arrow array type: {arrow_type}, expected type: {expected_arrow_type}"
                 )

-        if include_content:
-            for batch in self.data.to_batches():
-                for field in self.schema.items:
-                    _validate_content(batch.column(field.column), field.dtype)
-
-
-def _validate_content(array: pa.Array, dtype: bigframes.dtypes.Dtype):
-    """
-    Recursively validates the content of a PyArrow Array based on the
-    expected BigFrames dtype, focusing on complex types like JSON, structs,
-    and arrays where the Arrow type alone isn't sufficient.
-    """
-    # TODO: validate GEO data context.
-    if dtype == bigframes.dtypes.JSON_DTYPE:
-        values = array.to_pandas()
-        for data in values:
-            # Skip scalar null values to avoid `TypeError` from json.load.
-            if not utils.is_list_like(data) and pd.isna(data):
-                continue
-            try:
-                # Attempts JSON parsing.
-                json.loads(data)
-            except json.JSONDecodeError as e:
-                raise ValueError(f"Invalid JSON format found: {data!r}") from e
-    elif bigframes.dtypes.is_struct_like(dtype):
-        for field_name, dtype in bigframes.dtypes.get_struct_fields(dtype).items():
-            _validate_content(array.field(field_name), dtype)
-    elif bigframes.dtypes.is_array_like(dtype):
-        return _validate_content(
-            array.flatten(), bigframes.dtypes.get_array_inner_type(dtype)
-        )
-
-
 # Sequential iterator, but could split into batches and leverage parallelism for speed
 def _iter_table(
@@ -280,6 +263,34 @@ def _adapt_pandas_series(
 def _adapt_arrow_array(
     array: Union[pa.ChunkedArray, pa.Array]
 ) -> tuple[Union[pa.ChunkedArray, pa.Array], bigframes.dtypes.Dtype]:
+    """Normalize the array to managed storage types. Preverse shapes, only transforms values."""
+    if pa.types.is_struct(array.type):
+        assert isinstance(array, pa.StructArray)
+        assert isinstance(array.type, pa.StructType)
+        arrays = []
+        dtypes = []
+        pa_fields = []
+        for i in range(array.type.num_fields):
+            field_array, field_type = _adapt_arrow_array(array.field(i))
+            arrays.append(field_array)
+            dtypes.append(field_type)
+            pa_fields.append(pa.field(array.type.field(i).name, field_array.type))
+        struct_array = pa.StructArray.from_arrays(
+            arrays=arrays, fields=pa_fields, mask=array.is_null()
+        )
+        dtype = bigframes.dtypes.struct_type(
+            [(field.name, dtype) for field, dtype in zip(pa_fields, dtypes)]
+        )
+        return struct_array, dtype
+    if pa.types.is_list(array.type):
+        assert isinstance(array, pa.ListArray)
+        values, values_type = _adapt_arrow_array(array.values)
+        new_value = pa.ListArray.from_arrays(
+            array.offsets, values, mask=array.is_null()
+        )
+        return new_value.fill_null([]), bigframes.dtypes.list_type(values_type)
+    if array.type == bigframes.dtypes.JSON_ARROW_TYPE:
+        return _canonicalize_json(array), bigframes.dtypes.JSON_DTYPE
     target_type = _logical_type_replacements(array.type)
     if target_type != array.type:
         # TODO: Maybe warn if lossy conversion?
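The struct branch above rebuilds each child array recursively and passes mask=array.is_null() so that struct-level nulls survive the rebuild. A small self-contained pyarrow illustration of that pattern; the int64 cast only stands in for whatever per-field normalization _adapt_arrow_array would apply to the child:

import pyarrow as pa

array = pa.array([{"x": 1}, None, {"x": 3}], type=pa.struct([("x", pa.int32())]))
child = array.field(0).cast(pa.int64())  # stand-in for a recursively adapted child array
rebuilt = pa.StructArray.from_arrays(
    arrays=[child],
    fields=[pa.field("x", child.type)],
    mask=array.is_null(),  # without this, the parent-level null at index 1 would be lost
)
print(rebuilt.to_pylist())  # [{'x': 1}, None, {'x': 3}]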
@@ -292,6 +303,22 @@ def _adapt_arrow_array(
     return array, bf_type


+def _canonicalize_json(array: pa.Array) -> pa.Array:
+    def _canonicalize_scalar(json_string):
+        if json_string is None:
+            return None
+        # This is the canonical form that bq uses when emitting json
+        # The sorted keys and unambiguous whitespace ensures a 1:1 mapping
+        # between syntax and semantics.
+        return json.dumps(
+            json.loads(json_string), sort_keys=True, separators=(",", ":")
+        )
+
+    return pa.array(
+        [_canonicalize_scalar(value) for value in array.to_pylist()], type=pa.string()
+    )
+
+
 def _get_managed_storage_type(dtype: bigframes.dtypes.Dtype) -> pa.DataType:
     if dtype in _MANAGED_STORAGE_TYPES_OVERRIDES.keys():
         return _MANAGED_STORAGE_TYPES_OVERRIDES[dtype]

Review comment on _canonicalize_json: Please add comments why we need to canonicalize json here? Refer to:
Author reply: added comment
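On the reviewer's question: canonicalization gives a 1:1 mapping between JSON text and JSON value — semantically equal documents can otherwise differ in key order and whitespace — and the chosen form (sorted keys, minimal separators) matches what BigQuery emits for JSON, per the comment added in the diff. A standalone demonstration of the same stdlib transform:

import json

def canonicalize(json_string):
    # Sorted keys and minimal separators yield exactly one spelling per JSON value.
    return json.dumps(json.loads(json_string), sort_keys=True, separators=(",", ":"))

a = '{"b": 1,  "a": {"y": [1, 2],   "x": null}}'
b = '{"a":{"x":null,"y":[1,2]},"b":1}'
assert canonicalize(a) == canonicalize(b) == b
print(canonicalize(a))  # {"a":{"x":null,"y":[1,2]},"b":1}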
The PR also adds a new file that implements the local scan executor:

@@ -0,0 +1,68 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from typing import Optional
+
+from bigframes.core import bigframe_node, rewrite
+from bigframes.session import executor, semi_executor
+
+
+class LocalScanExecutor(semi_executor.SemiExecutor):
+    """
+    Executes plans reducible to a arrow table scan.
+    """
+
+    def execute(
+        self,
+        plan: bigframe_node.BigFrameNode,
+        ordered: bool,
+        peek: Optional[int] = None,
+    ) -> Optional[executor.ExecuteResult]:
+        node = rewrite.try_reduce_to_local_scan(plan)
+        if not node:
+            return None
+
+        # TODO: Can support some slicing, sorting
+        def iterator_supplier():
+            offsets_col = (
+                node.offsets_col.sql if (node.offsets_col is not None) else None
+            )
+            arrow_table = node.local_data_source.to_pyarrow_table(
+                offsets_col=offsets_col
+            )
+            if peek:
+                arrow_table = arrow_table.slice(0, peek)
+
+            needed_cols = [item.source_id for item in node.scan_list.items]
+            if offsets_col is not None:
+                needed_cols.append(offsets_col)
+
+            arrow_table = arrow_table.select(needed_cols)
+            arrow_table = arrow_table.rename_columns(
+                {item.source_id: item.id.sql for item in node.scan_list.items}
+            )
+            yield from arrow_table.to_batches()
+
+        total_rows = node.row_count
+        if (peek is not None) and (total_rows is not None):
+            total_rows = min(peek, total_rows)
+
+        return executor.ExecuteResult(
+            arrow_batches=iterator_supplier,
+            schema=plan.schema,
+            query_job=None,
+            total_bytes=None,
+            total_rows=total_rows,
+        )
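The iterator_supplier above is pure Arrow work: optionally slice for peek, select only the columns the scan list needs, rename them to the plan's output ids, and stream record batches without issuing a BigQuery job. A self-contained sketch of that sequence on a toy table; the column names are illustrative, and rename_columns is given the list form here for compatibility with older pyarrow (the PR uses the dict form):

import pyarrow as pa

table = pa.table({"src_a": [1, 2, 3, 4], "src_b": ["w", "x", "y", "z"], "unused": [0, 0, 0, 0]})
peek = 2
table = table.slice(0, peek)               # short-circuited "peek" limit
table = table.select(["src_a", "src_b"])   # only the columns in the scan list
table = table.rename_columns(["col_a", "col_b"])  # map source ids to output ids
for batch in table.to_batches():           # streamed to the caller, no query job
    print(batch.num_rows, batch.schema.names)  # 2 ['col_a', 'col_b']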
Review comment: With _canonicalize_json, you probably can remove the _validate_content method in this module.
Author reply: removed
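With content validation folded into canonicalization, the PR's short-circuit path is complete: LocalScanExecutor is a SemiExecutor that either serves a plan entirely from local Arrow data or returns None so the session can fall back to the regular BigQuery-backed executor. The dispatch logic itself is not part of this diff; the sketch below only illustrates that short-circuit pattern, and every class and function name in it is hypothetical rather than a bigframes API:

# Illustrative sketch of short-circuit dispatch implied by SemiExecutor.execute
# returning Optional[ExecuteResult]; names here are hypothetical.
class AlwaysNoneExecutor:
    def execute(self, plan, ordered, peek=None):
        return None  # cannot handle this plan; let the next executor try

class InMemoryExecutor:
    def execute(self, plan, ordered, peek=None):
        return f"local result for {plan!r}"  # stands in for an ExecuteResult

def execute_with_short_circuit(plan, semi_executors, fallback):
    for semi in semi_executors:
        result = semi.execute(plan, ordered=True)
        if result is not None:   # first executor that can serve the plan wins
            return result
    return fallback(plan)        # e.g. run the query in BigQuery

print(execute_with_short_circuit(
    "scan_plan",
    [AlwaysNoneExecutor(), InMemoryExecutor()],
    fallback=lambda p: "bigquery result",
))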