Skip to content

test: Run some engine tests on sqlglot compiler #1903

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions bigframes/session/direct_gbq_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
# limitations under the License.
from __future__ import annotations

from typing import Optional, Tuple
from typing import Literal, Optional, Tuple

from google.cloud import bigquery
import google.cloud.bigquery.job as bq_job
import google.cloud.bigquery.table as bq_table

from bigframes.core import compile, nodes
from bigframes.core.compile import sqlglot
from bigframes.session import executor, semi_executor
import bigframes.session._io.bigquery as bq_io

Expand All @@ -29,8 +30,15 @@
# or record metrics. Also avoids caching, and most pre-compile rewrites, to better serve as a
# reference for validating more complex executors.
class DirectGbqExecutor(semi_executor.SemiExecutor):
def __init__(self, bqclient: bigquery.Client):
def __init__(
self, bqclient: bigquery.Client, compiler: Literal["ibis", "sqlglot"] = "ibis"
):
self.bqclient = bqclient
self._compile_fn = (
compile.compile_sql
if compiler == "ibis"
else sqlglot.SQLGlotCompiler()._compile_sql
)

def execute(
self,
Expand All @@ -42,9 +50,10 @@ def execute(
# TODO(swast): plumb through the api_name of the user-facing api that
# caused this query.

compiled = compile.compile_sql(
compiled = self._compile_fn(
compile.CompileRequest(plan, sort_rows=ordered, peek_count=peek)
)

iterator, query_job = self._run_execute_query(
sql=compiled.sql,
)
Expand Down
6 changes: 5 additions & 1 deletion tests/system/small/engines/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,18 @@ def fake_session() -> Generator[bigframes.Session, None, None]:
yield session


@pytest.fixture(scope="session", params=["pyarrow", "polars", "bq"])
@pytest.fixture(scope="session", params=["pyarrow", "polars", "bq", "bq-sqlglot"])
def engine(request, bigquery_client: bigquery.Client) -> semi_executor.SemiExecutor:
if request.param == "pyarrow":
return local_scan_executor.LocalScanExecutor()
if request.param == "polars":
return polars_executor.PolarsExecutor()
if request.param == "bq":
return direct_gbq_execution.DirectGbqExecutor(bigquery_client)
if request.param == "bq-sqlglot":
return direct_gbq_execution.DirectGbqExecutor(
bigquery_client, compiler="sqlglot"
)
raise ValueError(f"Unrecognized param: {request.param}")


Expand Down
6 changes: 3 additions & 3 deletions tests/system/small/engines/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
REFERENCE_ENGINE = polars_executor.PolarsExecutor()


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
@pytest.mark.parametrize("join_type", ["left", "inner", "right", "outer"])
def test_engines_join_on_key(
scalars_array_value: array_value.ArrayValue,
Expand All @@ -41,7 +41,7 @@ def test_engines_join_on_key(
assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
@pytest.mark.parametrize("join_type", ["left", "inner", "right", "outer"])
def test_engines_join_on_coerced_key(
scalars_array_value: array_value.ArrayValue,
Expand Down Expand Up @@ -80,7 +80,7 @@ def test_engines_join_multi_key(
assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_cross_join(
scalars_array_value: array_value.ArrayValue,
engine,
Expand Down
2 changes: 2 additions & 0 deletions tests/system/small/engines/test_read_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ def test_engines_read_local_w_zero_row_source(
assert_equivalence_execution(local_node, REFERENCE_ENGINE, engine)


# TODO: Fix the sqlglot compiler for nested local sources, then add "bq-sqlglot" to this engine list.
@pytest.mark.parametrize("engine", ["polars", "bq", "pyarrow"], indirect=True)
def test_engines_read_local_w_nested_source(
fake_session: bigframes.Session,
nested_data_source: local_data.ManagedArrowTable,
Expand Down
8 changes: 4 additions & 4 deletions tests/system/small/engines/test_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
REFERENCE_ENGINE = polars_executor.PolarsExecutor()


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_reverse(
scalars_array_value: array_value.ArrayValue,
engine,
Expand All @@ -34,7 +34,7 @@ def test_engines_reverse(
assert_equivalence_execution(node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_double_reverse(
scalars_array_value: array_value.ArrayValue,
engine,
Expand All @@ -43,7 +43,7 @@ def test_engines_double_reverse(
assert_equivalence_execution(node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
@pytest.mark.parametrize(
"sort_col",
[
Expand All @@ -70,7 +70,7 @@ def test_engines_sort_over_column(
assert_equivalence_execution(node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_sort_multi_column_refs(
scalars_array_value: array_value.ArrayValue,
engine,
Expand Down