Skip to content

chore: add more grouping sets/rollup/cube tests #1029

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 71 additions & 47 deletions tests/unit/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,23 @@
from sqlalchemy.sql.functions import rollup, cube, grouping_sets


@pytest.fixture
def table(faux_conn, metadata):
# Fixture to create a sample table for testing

table = setup_table(
faux_conn,
"table1",
metadata,
sqlalchemy.Column("foo", sqlalchemy.Integer),
sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)),
)

yield table

table.drop(faux_conn)


def test_constraints_are_ignored(faux_conn, metadata):
sqlalchemy.Table(
"ref",
Expand Down Expand Up @@ -282,85 +299,92 @@ def test_no_implicit_join_for_inner_unnest_no_table2_column(faux_conn, metadata)
assert found_outer_sql == expected_outer_sql


def test_grouping_sets(faux_conn, metadata):
table = setup_table(
faux_conn,
"table1",
metadata,
sqlalchemy.Column("foo", sqlalchemy.Integer),
sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)),
grouping_ops = (
"grouping_op, grouping_op_func",
[("GROUPING SETS", grouping_sets), ("ROLLUP", rollup), ("CUBE", cube)],
)


@pytest.mark.parametrize(*grouping_ops)
def test_grouping_ops_vs_single_column(faux_conn, table, grouping_op, grouping_op_func):
# Tests each of the grouping ops against a single column

q = sqlalchemy.select(table.c.foo).group_by(grouping_op_func(table.c.foo))
found_sql = q.compile(faux_conn).string

expected_sql = (
f"SELECT `table1`.`foo` \n"
f"FROM `table1` GROUP BY {grouping_op}(`table1`.`foo`)"
)

assert found_sql == expected_sql


@pytest.mark.parametrize(*grouping_ops)
def test_grouping_ops_vs_multi_columns(faux_conn, table, grouping_op, grouping_op_func):
# Tests each of the grouping ops against multiple columns

q = sqlalchemy.select(table.c.foo, table.c.bar).group_by(
grouping_sets(table.c.foo, table.c.bar)
grouping_op_func(table.c.foo, table.c.bar)
)
found_sql = q.compile(faux_conn).string

expected_sql = (
"SELECT `table1`.`foo`, `table1`.`bar` \n"
"FROM `table1` GROUP BY GROUPING SETS(`table1`.`foo`, `table1`.`bar`)"
f"SELECT `table1`.`foo`, `table1`.`bar` \n"
f"FROM `table1` GROUP BY {grouping_op}(`table1`.`foo`, `table1`.`bar`)"
)
found_sql = q.compile(faux_conn).string

assert found_sql == expected_sql


def test_rollup(faux_conn, metadata):
table = setup_table(
faux_conn,
"table1",
metadata,
sqlalchemy.Column("foo", sqlalchemy.Integer),
sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)),
)
@pytest.mark.parametrize(*grouping_ops)
def test_grouping_op_with_grouping_op(faux_conn, table, grouping_op, grouping_op_func):
# Tests multiple grouping ops in a single statement

q = sqlalchemy.select(table.c.foo, table.c.bar).group_by(
rollup(table.c.foo, table.c.bar)
grouping_op_func(table.c.foo, table.c.bar), grouping_op_func(table.c.foo)
)
found_sql = q.compile(faux_conn).string

expected_sql = (
"SELECT `table1`.`foo`, `table1`.`bar` \n"
"FROM `table1` GROUP BY ROLLUP(`table1`.`foo`, `table1`.`bar`)"
f"SELECT `table1`.`foo`, `table1`.`bar` \n"
f"FROM `table1` GROUP BY {grouping_op}(`table1`.`foo`, `table1`.`bar`), {grouping_op}(`table1`.`foo`)"
)
found_sql = q.compile(faux_conn).string

assert found_sql == expected_sql


def test_cube(faux_conn, metadata):
table = setup_table(
faux_conn,
"table1",
metadata,
sqlalchemy.Column("foo", sqlalchemy.Integer),
sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)),
)
@pytest.mark.parametrize(*grouping_ops)
def test_grouping_ops_vs_group_by(faux_conn, table, grouping_op, grouping_op_func):
# Tests grouping op against regular group by statement

q = sqlalchemy.select(table.c.foo, table.c.bar).group_by(
cube(table.c.foo, table.c.bar)
table.c.foo, grouping_op_func(table.c.bar)
)
found_sql = q.compile(faux_conn).string

expected_sql = (
"SELECT `table1`.`foo`, `table1`.`bar` \n"
"FROM `table1` GROUP BY CUBE(`table1`.`foo`, `table1`.`bar`)"
f"SELECT `table1`.`foo`, `table1`.`bar` \n"
f"FROM `table1` GROUP BY `table1`.`foo`, {grouping_op}(`table1`.`bar`)"
)
found_sql = q.compile(faux_conn).string

assert found_sql == expected_sql


def test_multiple_grouping_sets(faux_conn, metadata):
table = setup_table(
faux_conn,
"table1",
metadata,
sqlalchemy.Column("foo", sqlalchemy.Integer),
sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)),
)
@pytest.mark.parametrize(*grouping_ops)
def test_complex_grouping_ops_vs_nested_grouping_ops(
faux_conn, table, grouping_op, grouping_op_func
):
# Tests grouping ops nested within grouping ops

q = sqlalchemy.select(table.c.foo, table.c.bar).group_by(
grouping_sets(table.c.foo, table.c.bar), grouping_sets(table.c.foo)
grouping_sets(table.c.foo, grouping_op_func(table.c.bar))
)
found_sql = q.compile(faux_conn).string

expected_sql = (
"SELECT `table1`.`foo`, `table1`.`bar` \n"
"FROM `table1` GROUP BY GROUPING SETS(`table1`.`foo`, `table1`.`bar`), GROUPING SETS(`table1`.`foo`)"
f"SELECT `table1`.`foo`, `table1`.`bar` \n"
f"FROM `table1` GROUP BY GROUPING SETS(`table1`.`foo`, {grouping_op}(`table1`.`bar`))"
)
found_sql = q.compile(faux_conn).string

assert found_sql == expected_sql