Skip to content

Commit 23754a8

Browse files
committed
Merge branch 'main' into shuowei-anywidget-deferred-mode
2 parents 876f8ae + ee83d98 commit 23754a8

File tree

3 files changed

+69
-17
lines changed

3 files changed

+69
-17
lines changed

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -152,13 +152,17 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
152152
return sge.Coalesce(this=left.expr, expressions=[right.expr])
153153

154154

155-
@register_unary_op(ops.RemoteFunctionOp, pass_op=True)
156-
def _(expr: TypedExpr, op: ops.RemoteFunctionOp) -> sge.Expression:
155+
def _get_remote_function_name(op):
157156
routine_ref = op.function_def.routine_ref
158157
# Quote project, dataset, and routine IDs to avoid keyword clashes.
159-
func_name = (
158+
return (
160159
f"`{routine_ref.project}`.`{routine_ref.dataset_id}`.`{routine_ref.routine_id}`"
161160
)
161+
162+
163+
@register_unary_op(ops.RemoteFunctionOp, pass_op=True)
164+
def _(expr: TypedExpr, op: ops.RemoteFunctionOp) -> sge.Expression:
165+
func_name = _get_remote_function_name(op)
162166
func = sge.func(func_name, expr.expr)
163167

164168
if not op.apply_on_null:
@@ -175,15 +179,16 @@ def _(expr: TypedExpr, op: ops.RemoteFunctionOp) -> sge.Expression:
175179
def _(
176180
left: TypedExpr, right: TypedExpr, op: ops.BinaryRemoteFunctionOp
177181
) -> sge.Expression:
178-
routine_ref = op.function_def.routine_ref
179-
# Quote project, dataset, and routine IDs to avoid keyword clashes.
180-
func_name = (
181-
f"`{routine_ref.project}`.`{routine_ref.dataset_id}`.`{routine_ref.routine_id}`"
182-
)
183-
182+
func_name = _get_remote_function_name(op)
184183
return sge.func(func_name, left.expr, right.expr)
185184

186185

186+
@register_nary_op(ops.NaryRemoteFunctionOp, pass_op=True)
187+
def _(*operands: TypedExpr, op: ops.NaryRemoteFunctionOp) -> sge.Expression:
188+
func_name = _get_remote_function_name(op)
189+
return sge.func(func_name, *(operand.expr for operand in operands))
190+
191+
187192
@register_nary_op(ops.case_when_op)
188193
def _(*cases_and_outputs: TypedExpr) -> sge.Expression:
189194
# Need to upcast BOOL to INT if any output is numeric
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`float64_col`,
4+
`int64_col`,
5+
`string_col`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
*,
10+
`my_project`.`my_dataset`.`my_routine`(`int64_col`, `float64_col`, `string_col`) AS `bfcol_3`
11+
FROM `bfcte_0`
12+
)
13+
SELECT
14+
`bfcol_3` AS `int64_col`
15+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from google.cloud import bigquery
1516
import pandas as pd
1617
import pytest
1718

1819
from bigframes import dtypes
1920
from bigframes import operations as ops
2021
from bigframes.core import expression as ex
22+
from bigframes.functions import udf_def
2123
import bigframes.pandas as bpd
2224
from bigframes.testing import utils
2325

@@ -170,10 +172,6 @@ def test_astype_json_invalid(
170172

171173

172174
def test_remote_function_op(scalar_types_df: bpd.DataFrame, snapshot):
173-
from google.cloud import bigquery
174-
175-
from bigframes.functions import udf_def
176-
177175
bf_df = scalar_types_df[["int64_col"]]
178176
function_def = udf_def.BigqueryUdf(
179177
routine_ref=bigquery.RoutineReference.from_string(
@@ -206,10 +204,6 @@ def test_remote_function_op(scalar_types_df: bpd.DataFrame, snapshot):
206204

207205

208206
def test_binary_remote_function_op(scalar_types_df: bpd.DataFrame, snapshot):
209-
from google.cloud import bigquery
210-
211-
from bigframes.functions import udf_def
212-
213207
bf_df = scalar_types_df[["int64_col", "float64_col"]]
214208
op = ops.BinaryRemoteFunctionOp(
215209
function_def=udf_def.BigqueryUdf(
@@ -242,6 +236,44 @@ def test_binary_remote_function_op(scalar_types_df: bpd.DataFrame, snapshot):
242236
snapshot.assert_match(sql, "out.sql")
243237

244238

239+
def test_nary_remote_function_op(scalar_types_df: bpd.DataFrame, snapshot):
240+
bf_df = scalar_types_df[["int64_col", "float64_col", "string_col"]]
241+
op = ops.NaryRemoteFunctionOp(
242+
function_def=udf_def.BigqueryUdf(
243+
routine_ref=bigquery.RoutineReference.from_string(
244+
"my_project.my_dataset.my_routine"
245+
),
246+
signature=udf_def.UdfSignature(
247+
input_types=(
248+
udf_def.UdfField(
249+
"x",
250+
bigquery.StandardSqlDataType(
251+
type_kind=bigquery.StandardSqlTypeNames.INT64
252+
),
253+
),
254+
udf_def.UdfField(
255+
"y",
256+
bigquery.StandardSqlDataType(
257+
type_kind=bigquery.StandardSqlTypeNames.FLOAT64
258+
),
259+
),
260+
udf_def.UdfField(
261+
"z",
262+
bigquery.StandardSqlDataType(
263+
type_kind=bigquery.StandardSqlTypeNames.STRING
264+
),
265+
),
266+
),
267+
output_bq_type=bigquery.StandardSqlDataType(
268+
type_kind=bigquery.StandardSqlTypeNames.FLOAT64
269+
),
270+
),
271+
)
272+
)
273+
sql = utils._apply_nary_op(bf_df, op, "int64_col", "float64_col", "string_col")
274+
snapshot.assert_match(sql, "out.sql")
275+
276+
245277
def test_case_when_op(scalar_types_df: bpd.DataFrame, snapshot):
246278
ops_map = {
247279
"single_case": ops.case_when_op.as_expr(

0 commit comments

Comments
 (0)