Skip to content

Commit 943c0ad

Browse files
test: Run some engine tests on sqlglot compiler
1 parent 1aa7950 commit 943c0ad

File tree

5 files changed

+26
-11
lines changed

5 files changed

+26
-11
lines changed

bigframes/session/direct_gbq_execution.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,13 +13,14 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16-
from typing import Optional, Tuple
16+
from typing import Literal, Optional, Tuple
1717

1818
from google.cloud import bigquery
1919
import google.cloud.bigquery.job as bq_job
2020
import google.cloud.bigquery.table as bq_table
2121

2222
from bigframes.core import compile, nodes
23+
from bigframes.core.compile import sqlglot
2324
from bigframes.session import executor, semi_executor
2425
import bigframes.session._io.bigquery as bq_io
2526

@@ -29,8 +30,15 @@
2930
# or record metrics. Also avoids caching, and most pre-compile rewrites, to better serve as a
3031
# reference for validating more complex executors.
3132
class DirectGbqExecutor(semi_executor.SemiExecutor):
32-
def __init__(self, bqclient: bigquery.Client):
33+
def __init__(
34+
self, bqclient: bigquery.Client, compiler: Literal["ibis", "sqlglot"] = "ibis"
35+
):
3336
self.bqclient = bqclient
37+
self._compile_fn = (
38+
compile.compile_sql
39+
if compiler == "ibis"
40+
else sqlglot.SQLGlotCompiler()._compile_sql
41+
)
3442

3543
def execute(
3644
self,
@@ -42,9 +50,10 @@ def execute(
4250
# TODO(swast): plumb through the api_name of the user-facing api that
4351
# caused this query.
4452

45-
compiled = compile.compile_sql(
53+
compiled = self._compile_fn(
4654
compile.CompileRequest(plan, sort_rows=ordered, peek_count=peek)
4755
)
56+
4857
iterator, query_job = self._run_execute_query(
4958
sql=compiled.sql,
5059
)

tests/system/small/engines/conftest.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -44,14 +44,18 @@ def fake_session() -> Generator[bigframes.Session, None, None]:
4444
yield session
4545

4646

47-
@pytest.fixture(scope="session", params=["pyarrow", "polars", "bq"])
47+
@pytest.fixture(scope="session", params=["pyarrow", "polars", "bq", "bq-sqlglot"])
4848
def engine(request, bigquery_client: bigquery.Client) -> semi_executor.SemiExecutor:
4949
if request.param == "pyarrow":
5050
return local_scan_executor.LocalScanExecutor()
5151
if request.param == "polars":
5252
return polars_executor.PolarsExecutor()
5353
if request.param == "bq":
5454
return direct_gbq_execution.DirectGbqExecutor(bigquery_client)
55+
if request.param == "bq-sqlglot":
56+
return direct_gbq_execution.DirectGbqExecutor(
57+
bigquery_client, compiler="sqlglot"
58+
)
5559
raise ValueError(f"Unrecognized param: {request.param}")
5660

5761

tests/system/small/engines/test_join.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727
REFERENCE_ENGINE = polars_executor.PolarsExecutor()
2828

2929

30-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
30+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
3131
@pytest.mark.parametrize("join_type", ["left", "inner", "right", "outer"])
3232
def test_engines_join_on_key(
3333
scalars_array_value: array_value.ArrayValue,
@@ -41,7 +41,7 @@ def test_engines_join_on_key(
4141
assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine)
4242

4343

44-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
44+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
4545
@pytest.mark.parametrize("join_type", ["left", "inner", "right", "outer"])
4646
def test_engines_join_on_coerced_key(
4747
scalars_array_value: array_value.ArrayValue,
@@ -80,7 +80,7 @@ def test_engines_join_multi_key(
8080
assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine)
8181

8282

83-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
83+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
8484
def test_engines_cross_join(
8585
scalars_array_value: array_value.ArrayValue,
8686
engine,

tests/system/small/engines/test_read_local.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -88,6 +88,8 @@ def test_engines_read_local_w_zero_row_source(
8888
assert_equivalence_execution(local_node, REFERENCE_ENGINE, engine)
8989

9090

91+
# TODO: Fix sqlglot impl
92+
@pytest.mark.parametrize("engine", ["polars", "bq", "pyarrow"], indirect=True)
9193
def test_engines_read_local_w_nested_source(
9294
fake_session: bigframes.Session,
9395
nested_data_source: local_data.ManagedArrowTable,

tests/system/small/engines/test_sorting.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@
2525
REFERENCE_ENGINE = polars_executor.PolarsExecutor()
2626

2727

28-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
28+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
2929
def test_engines_reverse(
3030
scalars_array_value: array_value.ArrayValue,
3131
engine,
@@ -34,7 +34,7 @@ def test_engines_reverse(
3434
assert_equivalence_execution(node, REFERENCE_ENGINE, engine)
3535

3636

37-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
37+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
3838
def test_engines_double_reverse(
3939
scalars_array_value: array_value.ArrayValue,
4040
engine,
@@ -43,7 +43,7 @@ def test_engines_double_reverse(
4343
assert_equivalence_execution(node, REFERENCE_ENGINE, engine)
4444

4545

46-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
46+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
4747
@pytest.mark.parametrize(
4848
"sort_col",
4949
[
@@ -70,7 +70,7 @@ def test_engines_sort_over_column(
7070
assert_equivalence_execution(node, REFERENCE_ENGINE, engine)
7171

7272

73-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
73+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
7474
def test_engines_sort_multi_column_refs(
7575
scalars_array_value: array_value.ArrayValue,
7676
engine,

0 commit comments

Comments
 (0)