
Commit 6adb0e1

Merge branch 'main' into main_chelsealin_enablesqlglot2
2 parents f900cdb + d02d32f

File tree

13 files changed: 147 additions, 92 deletions


bigframes/core/blocks.py

Lines changed: 23 additions & 42 deletions
@@ -818,49 +818,30 @@ def _materialize_local(
             total_rows = result_batches.approx_total_rows
             # Remove downsampling config from subsequent invocations, as otherwise could result in many
             # iterations if downsampling undershoots
-            return self._downsample(
-                total_rows=total_rows,
-                sampling_method=sample_config.sampling_method,
-                fraction=fraction,
-                random_state=sample_config.random_state,
-            )._materialize_local(
-                MaterializationOptions(ordered=materialize_options.ordered)
-            )
-        else:
-            df = result_batches.to_pandas()
-            df = self._copy_index_to_pandas(df)
-            df.set_axis(self.column_labels, axis=1, copy=False)
-            return df, execute_result.query_job
-
-    def _downsample(
-        self, total_rows: int, sampling_method: str, fraction: float, random_state
-    ) -> Block:
-        # either selecting fraction or number of rows
-        if sampling_method == _HEAD:
-            filtered_block = self.slice(stop=int(total_rows * fraction))
-            return filtered_block
-        elif (sampling_method == _UNIFORM) and (random_state is None):
-            filtered_expr = self.expr._uniform_sampling(fraction)
-            block = Block(
-                filtered_expr,
-                index_columns=self.index_columns,
-                column_labels=self.column_labels,
-                index_labels=self.index.names,
-            )
-            return block
-        elif sampling_method == _UNIFORM:
-            block = self.split(
-                fracs=(fraction,),
-                random_state=random_state,
-                sort=False,
-            )[0]
-            return block
+            if sample_config.sampling_method == "head":
+                # Just truncates the result iterator without a follow-up query
+                raw_df = result_batches.to_pandas(limit=int(total_rows * fraction))
+            elif (
+                sample_config.sampling_method == "uniform"
+                and sample_config.random_state is None
+            ):
+                # Pushes sample into result without new query
+                sampled_batches = execute_result.batches(sample_rate=fraction)
+                raw_df = sampled_batches.to_pandas()
+            else:  # uniform sample with random state requires a full follow-up query
+                down_sampled_block = self.split(
+                    fracs=(fraction,),
+                    random_state=sample_config.random_state,
+                    sort=False,
+                )[0]
+                return down_sampled_block._materialize_local(
+                    MaterializationOptions(ordered=materialize_options.ordered)
+                )
         else:
-            # This part should never be called, just in case.
-            raise NotImplementedError(
-                f"The downsampling method {sampling_method} is not implemented, "
-                f"please choose from {','.join(_SAMPLING_METHODS)}."
-            )
+            raw_df = result_batches.to_pandas()
+        df = self._copy_index_to_pandas(raw_df)
+        df.set_axis(self.column_labels, axis=1, copy=False)
+        return df, execute_result.query_job

     def split(
         self,

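The rewritten `_materialize_local` inlines the old `_downsample` helper and orders the strategies by cost: `head` just truncates the batch iterator, unseeded `uniform` pushes a `RAND()` predicate into the read already in flight, and seeded `uniform` still pays for a follow-up `split` query. A minimal sketch of that dispatch, with duck-typed stand-ins for the real `Block` internals:

from typing import Optional

def materialize_sample(
    result_batches,    # ResultsIterator-like: .to_pandas(limit=...)
    execute_result,    # ExecuteResult-like: .batches(sample_rate=...)
    block,             # Block-like: .split(fracs=..., random_state=..., sort=...)
    method: str,
    fraction: float,
    total_rows: int,
    random_state: Optional[int],
):
    if method == "head":
        # Cheapest path: stop pulling batches once enough rows have arrived.
        return result_batches.to_pandas(limit=int(total_rows * fraction))
    if method == "uniform" and random_state is None:
        # Push a RAND() < fraction row restriction into the existing read.
        return execute_result.batches(sample_rate=fraction).to_pandas()
    # Reproducible sampling needs a hash-based split, i.e. a follow-up query;
    # the caller re-materializes the returned block.
    return block.split(fracs=(fraction,), random_state=random_state, sort=False)[0]
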
bigframes/core/bq_data.py

Lines changed: 12 additions & 1 deletion
@@ -186,11 +186,22 @@ def get_arrow_batches(
     columns: Sequence[str],
     storage_read_client: bigquery_storage_v1.BigQueryReadClient,
     project_id: str,
+    sample_rate: Optional[float] = None,
 ) -> ReadResult:
     table_mod_options = {}
     read_options_dict: dict[str, Any] = {"selected_fields": list(columns)}
+
+    predicates = []
     if data.sql_predicate:
-        read_options_dict["row_restriction"] = data.sql_predicate
+        predicates.append(data.sql_predicate)
+    if sample_rate is not None:
+        assert isinstance(sample_rate, float)
+        predicates.append(f"RAND() < {sample_rate}")
+
+    if predicates:
+        full_predicates = " AND ".join(f"( {pred} )" for pred in predicates)
+        read_options_dict["row_restriction"] = full_predicates
+
     read_options = bq_storage_types.ReadSession.TableReadOptions(**read_options_dict)

     if data.at_time:

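The `row_restriction` change composes the user's predicate with the sampling predicate, parenthesizing each so that `AND` cannot bind into an inner `OR`. A standalone sketch of the same merging (`build_row_restriction` is a hypothetical helper, not part of the module):

from typing import Optional

def build_row_restriction(
    sql_predicate: Optional[str], sample_rate: Optional[float]
) -> Optional[str]:
    predicates = []
    if sql_predicate:
        predicates.append(sql_predicate)
    if sample_rate is not None:
        predicates.append(f"RAND() < {sample_rate}")
    if not predicates:
        return None
    # Parenthesize each predicate before joining with AND.
    return " AND ".join(f"( {pred} )" for pred in predicates)

assert (
    build_row_restriction("a = 1 OR b = 2", 0.25)
    == "( a = 1 OR b = 2 ) AND ( RAND() < 0.25 )"
)
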
bigframes/core/compile/sqlglot/expressions/numeric_ops.py

Lines changed: 57 additions & 21 deletions
@@ -93,12 +93,19 @@ def _(expr: TypedExpr) -> sge.Expression:
 def _(expr: TypedExpr) -> sge.Expression:
     return sge.Case(
         ifs=[
+            # |x| < 1: The standard formula
+            sge.If(
+                this=sge.func("ABS", expr.expr) < sge.convert(1),
+                true=sge.func("ATANH", expr.expr),
+            ),
+            # |x| > 1: Returns NaN
             sge.If(
                 this=sge.func("ABS", expr.expr) > sge.convert(1),
                 true=constants._NAN,
-            )
+            ),
         ],
-        default=sge.func("ATANH", expr.expr),
+        # |x| = 1: Returns Infinity or -Infinity
+        default=sge.Mul(this=constants._INF, expression=expr.expr),
     )
 
 
@@ -145,15 +152,11 @@ def _(expr: TypedExpr) -> sge.Expression:
 
 @register_unary_op(ops.expm1_op)
 def _(expr: TypedExpr) -> sge.Expression:
-    return sge.Case(
-        ifs=[
-            sge.If(
-                this=expr.expr > constants._FLOAT64_EXP_BOUND,
-                true=constants._INF,
-            )
-        ],
-        default=sge.func("EXP", expr.expr),
-    ) - sge.convert(1)
+    return sge.If(
+        this=expr.expr > constants._FLOAT64_EXP_BOUND,
+        true=constants._INF,
+        false=sge.func("EXP", expr.expr) - sge.convert(1),
+    )
 
 
 @register_unary_op(ops.floor_op)
@@ -166,11 +169,22 @@ def _(expr: TypedExpr) -> sge.Expression:
     return sge.Case(
         ifs=[
             sge.If(
-                this=expr.expr <= sge.convert(0),
+                this=sge.Is(this=expr.expr, expression=sge.Null()),
+                true=sge.null(),
+            ),
+            # |x| > 0: The standard formula
+            sge.If(
+                this=expr.expr > sge.convert(0),
+                true=sge.Ln(this=expr.expr),
+            ),
+            # |x| < 0: Returns NaN
+            sge.If(
+                this=expr.expr < sge.convert(0),
                 true=constants._NAN,
-            )
+            ),
         ],
-        default=sge.Ln(this=expr.expr),
+        # |x| == 0: Returns -Infinity
+        default=constants._NEG_INF,
     )
 
 
@@ -179,11 +193,22 @@ def _(expr: TypedExpr) -> sge.Expression:
     return sge.Case(
         ifs=[
             sge.If(
-                this=expr.expr <= sge.convert(0),
+                this=sge.Is(this=expr.expr, expression=sge.Null()),
+                true=sge.null(),
+            ),
+            # |x| > 0: The standard formula
+            sge.If(
+                this=expr.expr > sge.convert(0),
+                true=sge.Log(this=sge.convert(10), expression=expr.expr),
+            ),
+            # |x| < 0: Returns NaN
+            sge.If(
+                this=expr.expr < sge.convert(0),
                 true=constants._NAN,
-            )
+            ),
         ],
-        default=sge.Log(this=expr.expr, expression=sge.convert(10)),
+        # |x| == 0: Returns -Infinity
+        default=constants._NEG_INF,
     )
 
 
@@ -192,11 +217,22 @@ def _(expr: TypedExpr) -> sge.Expression:
     return sge.Case(
         ifs=[
             sge.If(
-                this=expr.expr <= sge.convert(-1),
+                this=sge.Is(this=expr.expr, expression=sge.Null()),
+                true=sge.null(),
+            ),
+            # Domain: |x| > -1 (The standard formula)
+            sge.If(
+                this=expr.expr > sge.convert(-1),
+                true=sge.Ln(this=sge.convert(1) + expr.expr),
+            ),
+            # Out of Domain: |x| < -1 (Returns NaN)
+            sge.If(
+                this=expr.expr < sge.convert(-1),
                 true=constants._NAN,
-            )
+            ),
         ],
-        default=sge.Ln(this=sge.convert(1) + expr.expr),
+        # Boundary: |x| == -1 (Returns -Infinity)
+        default=constants._NEG_INF,
     )
 
 
@@ -608,7 +644,7 @@ def isfinite(arg: TypedExpr) -> sge.Expression:
     return sge.Not(
         this=sge.Or(
            this=sge.IsInf(this=arg.expr),
-            right=sge.IsNan(this=arg.expr),
+            expression=sge.IsNan(this=arg.expr),
         ),
    )
 

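The log-family ops (`ln`, `log10`, `log1p`) now share a four-branch shape: NULL passthrough, the in-domain formula, NaN out of domain, and an explicit boundary default. A freestanding sketch of the `ln` variant using sqlglot's public builders (`_NAN` and `_NEG_INF` stand in for the module's `constants`):

import sqlglot.expressions as sge

_NAN = sge.Cast(this=sge.convert("NaN"), to=sge.DataType.build("FLOAT64", dialect="bigquery"))
_NEG_INF = sge.Cast(this=sge.convert("-Infinity"), to=sge.DataType.build("FLOAT64", dialect="bigquery"))

col = sge.column("float64_col")
ln_case = sge.Case(
    ifs=[
        # NULL in, NULL out (CASE would otherwise fall through to the default).
        sge.If(this=sge.Is(this=col, expression=sge.Null()), true=sge.null()),
        sge.If(this=col > sge.convert(0), true=sge.Ln(this=col)),  # x > 0
        sge.If(this=col < sge.convert(0), true=_NAN),              # x < 0
    ],
    default=_NEG_INF,                                              # x == 0
)
print(ln_case.sql(dialect="bigquery"))
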
bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 1 addition & 1 deletion
@@ -674,7 +674,7 @@ def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
             expressions=[_literal(value=v, dtype=value_type) for v in value]
         )
         return values if len(value) > 0 else _cast(values, sqlglot_type)
-    elif pd.isna(value):
+    elif pd.isna(value) or (isinstance(value, pa.Scalar) and not value.is_valid):
         return _cast(sge.Null(), sqlglot_type)
     elif dtype == dtypes.JSON_DTYPE:
         return sge.ParseJSON(this=sge.convert(str(value)))

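The `_literal` tweak matters because an Arrow null arrives as a `pa.Scalar` whose `is_valid` is False, and `pd.isna` does not treat such a scalar object as missing. A quick repro of the gap, assuming current pandas/pyarrow behavior:

import pandas as pd
import pyarrow as pa

null_scalar = pa.scalar(None, type=pa.int64())
print(null_scalar.is_valid)  # False: an Arrow-level NULL
print(pd.isna(null_scalar))  # False: pd.isna does not unwrap pa.Scalar objects
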
bigframes/core/local_data.py

Lines changed: 10 additions & 1 deletion
@@ -25,6 +25,7 @@
 import uuid
 
 import geopandas  # type: ignore
+import numpy
 import numpy as np
 import pandas as pd
 import pyarrow as pa
@@ -124,13 +125,21 @@ def to_arrow(
         geo_format: Literal["wkb", "wkt"] = "wkt",
         duration_type: Literal["int", "duration"] = "duration",
         json_type: Literal["string"] = "string",
+        sample_rate: Optional[float] = None,
         max_chunksize: Optional[int] = None,
     ) -> tuple[pa.Schema, Iterable[pa.RecordBatch]]:
         if geo_format != "wkt":
             raise NotImplementedError(f"geo format {geo_format} not yet implemented")
         assert json_type == "string"
 
-        batches = self.data.to_batches(max_chunksize=max_chunksize)
+        data = self.data
+
+        # This exists for symmetry with remote sources, but sampling local data like this shouldn't really happen
+        if sample_rate is not None:
+            to_take = numpy.random.rand(data.num_rows) < sample_rate
+            data = data.filter(to_take)
+
+        batches = data.to_batches(max_chunksize=max_chunksize)
         schema = self.data.schema
         if duration_type == "int":
             schema = _schema_durations_to_ints(schema)

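For the local path there is no query engine to push a predicate into, so the sample is applied eagerly with a boolean mask over the whole table. The same pattern in isolation:

import numpy as np
import pyarrow as pa

table = pa.table({"x": list(range(1000))})
# Keep each row independently with probability sample_rate.
sample_rate = 0.3
mask = np.random.rand(table.num_rows) < sample_rate
sampled = table.filter(mask)  # pa.Table.filter accepts a boolean mask
print(sampled.num_rows)       # ~300 in expectation
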
bigframes/session/executor.py

Lines changed: 15 additions & 11 deletions
@@ -88,7 +88,7 @@ def arrow_batches(self) -> Iterator[pyarrow.RecordBatch]:
 
             yield batch
 
-    def to_arrow_table(self) -> pyarrow.Table:
+    def to_arrow_table(self, limit: Optional[int] = None) -> pyarrow.Table:
         # Need to provide schema if no result rows, as arrow can't infer
         # If ther are rows, it is safest to infer schema from batches.
         # Any discrepencies between predicted schema and actual schema will produce errors.
@@ -97,18 +97,21 @@ def to_arrow_table(self) -> pyarrow.Table:
         peek_value = list(peek_it)
         # TODO: Enforce our internal schema on the table for consistency
         if len(peek_value) > 0:
-            return pyarrow.Table.from_batches(
-                itertools.chain(peek_value, batches),  # reconstruct
-            )
+            batches = itertools.chain(peek_value, batches)  # reconstruct
+            if limit:
+                batches = pyarrow_utils.truncate_pyarrow_iterable(
+                    batches, max_results=limit
+                )
+            return pyarrow.Table.from_batches(batches)
         else:
             try:
                 return self._schema.to_pyarrow().empty_table()
             except pa.ArrowNotImplementedError:
                 # Bug with some pyarrow versions, empty_table only supports base storage types, not extension types.
                 return self._schema.to_pyarrow(use_storage_types=True).empty_table()
 
-    def to_pandas(self) -> pd.DataFrame:
-        return io_pandas.arrow_to_pandas(self.to_arrow_table(), self._schema)
+    def to_pandas(self, limit: Optional[int] = None) -> pd.DataFrame:
+        return io_pandas.arrow_to_pandas(self.to_arrow_table(limit=limit), self._schema)
 
     def to_pandas_batches(
         self, page_size: Optional[int] = None, max_results: Optional[int] = None
@@ -158,7 +161,7 @@ def schema(self) -> bigframes.core.schema.ArraySchema:
         ...
 
     @abc.abstractmethod
-    def batches(self) -> ResultsIterator:
+    def batches(self, sample_rate: Optional[float] = None) -> ResultsIterator:
         ...
 
     @property
@@ -200,9 +203,9 @@ def execution_metadata(self) -> ExecutionMetadata:
     def schema(self) -> bigframes.core.schema.ArraySchema:
         return self._data.schema
 
-    def batches(self) -> ResultsIterator:
+    def batches(self, sample_rate: Optional[float] = None) -> ResultsIterator:
         return ResultsIterator(
-            iter(self._data.to_arrow()[1]),
+            iter(self._data.to_arrow(sample_rate=sample_rate)[1]),
             self.schema,
             self._data.metadata.row_count,
             self._data.metadata.total_bytes,
@@ -226,7 +229,7 @@ def execution_metadata(self) -> ExecutionMetadata:
     def schema(self) -> bigframes.core.schema.ArraySchema:
         return self._schema
 
-    def batches(self) -> ResultsIterator:
+    def batches(self, sample_rate: Optional[float] = None) -> ResultsIterator:
         return ResultsIterator(iter([]), self.schema, 0, 0)
 
 
@@ -260,12 +263,13 @@ def schema(self) -> bigframes.core.schema.ArraySchema:
         source_ids = [selection[0] for selection in self._selected_fields]
         return self._data.schema.select(source_ids).rename(dict(self._selected_fields))
 
-    def batches(self) -> ResultsIterator:
+    def batches(self, sample_rate: Optional[float] = None) -> ResultsIterator:
         read_batches = bq_data.get_arrow_batches(
             self._data,
             [x[0] for x in self._selected_fields],
             self._storage_client,
             self._project_id,
+            sample_rate=sample_rate,
         )
         arrow_batches: Iterator[pa.RecordBatch] = map(
             functools.partial(

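`to_arrow_table(limit=...)` leans on `pyarrow_utils.truncate_pyarrow_iterable` to stop consuming batches once the row budget is spent. That helper isn't shown in this diff; a plausible equivalent (an assumption, not the library's code) would be:

from typing import Iterable, Iterator

import pyarrow as pa

def truncate_pyarrow_iterable_sketch(
    batches: Iterable[pa.RecordBatch], max_results: int
) -> Iterator[pa.RecordBatch]:
    # Hypothetical stand-in for pyarrow_utils.truncate_pyarrow_iterable:
    # yield whole batches while under budget, slicing the final one.
    remaining = max_results
    for batch in batches:
        if remaining <= 0:
            return
        if batch.num_rows <= remaining:
            yield batch
            remaining -= batch.num_rows
        else:
            yield batch.slice(0, remaining)
            return
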
tests/system/small/test_anywidget.py

Lines changed: 1 addition & 1 deletion
@@ -165,7 +165,7 @@ def execution_metadata(self) -> ExecutionMetadata:
     def schema(self) -> Any:
         return schema
 
-    def batches(self) -> ResultsIterator:
+    def batches(self, sample_rate=None) -> ResultsIterator:
         return ResultsIterator(
             arrow_batches_val,
             self.schema,

tests/system/small/test_dataframe.py

Lines changed: 3 additions & 3 deletions
@@ -4524,7 +4524,7 @@ def test_df_kurt(scalars_dfs):
         "n_default",
     ],
 )
-def test_sample(scalars_dfs, frac, n, random_state):
+def test_df_to_pandas_sample(scalars_dfs, frac, n, random_state):
     scalars_df, _ = scalars_dfs
     df = scalars_df.sample(frac=frac, n=n, random_state=random_state)
     bf_result = df.to_pandas()
@@ -4535,15 +4535,15 @@ def test_sample(scalars_dfs, frac, n, random_state):
     assert bf_result.shape[1] == scalars_df.shape[1]
 
 
-def test_sample_determinism(penguins_df_default_index):
+def test_df_to_pandas_sample_determinism(penguins_df_default_index):
     df = penguins_df_default_index.sample(n=100, random_state=12345).head(15)
     bf_result = df.to_pandas()
     bf_result2 = df.to_pandas()
 
     pandas.testing.assert_frame_equal(bf_result, bf_result2)
 
 
-def test_sample_raises_value_error(scalars_dfs):
+def test_df_to_pandas_sample_raises_value_error(scalars_dfs):
     scalars_df, _ = scalars_dfs
     with pytest.raises(
         ValueError, match="Only one of 'n' or 'frac' parameter can be specified."

tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_arctanh/out.sql

Lines changed: 3 additions & 1 deletion
@@ -6,9 +6,11 @@ WITH `bfcte_0` AS (
   SELECT
     *,
     CASE
+      WHEN ABS(`float64_col`) < 1
+      THEN ATANH(`float64_col`)
       WHEN ABS(`float64_col`) > 1
       THEN CAST('NaN' AS FLOAT64)
-      ELSE ATANH(`float64_col`)
+      ELSE CAST('Infinity' AS FLOAT64) * `float64_col`
     END AS `bfcol_1`
   FROM `bfcte_0`
 )

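The new snapshot matches numpy's arctanh at every boundary; note how ELSE CAST('Infinity' AS FLOAT64) * float64_col yields +Infinity for 1, -Infinity for -1, and NULL for NULL inputs. A quick cross-check:

import numpy as np

with np.errstate(divide="ignore", invalid="ignore"):
    print(np.arctanh(0.5))   # ~0.5493: the |x| < 1 ATANH branch
    print(np.arctanh(1.0))   # inf:  Infinity * 1
    print(np.arctanh(-1.0))  # -inf: Infinity * -1
    print(np.arctanh(2.0))   # nan:  the |x| > 1 branch
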