Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import Mapper, validates
from sqlalchemy.sql.elements import ColumnElement, literal_column, TextClause
from sqlalchemy.sql.elements import ColumnElement, Grouping, literal_column, TextClause
from sqlalchemy.sql.expression import Label, Select, TextAsFrom
from sqlalchemy.sql.selectable import Alias, TableClause
from sqlalchemy_utils import UUIDType
Expand Down Expand Up @@ -2980,7 +2980,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
where_clause_and: list[ColumnElement] = []
having_clause_and: list[ColumnElement] = []

for flt in filter: # type: ignore
for flt in filter or []:
if not all(flt.get(s) for s in ["col", "op"]):
continue
flt_col = flt["col"]
Expand Down Expand Up @@ -3221,7 +3221,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
schema=self.schema,
template_processor=template_processor,
)
where_clause_and += [self.text(where)]
where_clause_and += [Grouping(self.text(where))]
having = extras.get("having")
if having:
having = self._process_select_expression(
Expand All @@ -3231,7 +3231,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
schema=self.schema,
template_processor=template_processor,
)
having_clause_and += [self.text(having)]
having_clause_and += [Grouping(self.text(having))]

if apply_fetch_values_predicate and self.fetch_values_predicate:
qry = qry.where(
Expand Down
137 changes: 125 additions & 12 deletions tests/unit_tests/models/helpers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,9 +477,9 @@ def condition_factory(col_name: str, expr):
has_single_quotes = "'Others'" in select_sql and "'Others'" in groupby_sql
has_double_quotes = '"Others"' in select_sql and '"Others"' in groupby_sql

assert has_single_quotes or has_double_quotes, (
"Others literal should be quoted with either single or double quotes"
)
assert (
has_single_quotes or has_double_quotes
), "Others literal should be quoted with either single or double quotes"

# Verify the structure of the generated SQL
assert "CASE WHEN" in select_sql
Expand Down Expand Up @@ -1121,13 +1121,13 @@ def test_process_select_expression_end_to_end(database: Database) -> None:
# sqlglot may normalize the SQL slightly, so we check the result exists
# and doesn't contain the SELECT prefix
assert result is not None, f"Failed to process: {expression}"
assert not result.upper().startswith("SELECT"), (
f"Result still has SELECT prefix: {result}"
)
assert not result.upper().startswith(
"SELECT"
), f"Result still has SELECT prefix: {result}"
# The result should contain the core expression (case-insensitive check)
assert expected.replace(" ", "").lower() in result.replace(" ", "").lower(), (
f"Expected '{expected}' to be in result '{result}' for input '{expression}'"
)
assert (
expected.replace(" ", "").lower() in result.replace(" ", "").lower()
), f"Expected '{expected}' to be in result '{result}' for input '{expression}'"


def test_reapply_query_filters_with_granularity(database: Database) -> None:
Expand Down Expand Up @@ -1641,9 +1641,9 @@ def test_adhoc_column_with_spaces_generates_quoted_sql(database: Database) -> No
)
)

assert '"Order Total"' in sql_numeric, (
f"Expected quoted column name in SQL: {sql_numeric}"
)
assert (
'"Order Total"' in sql_numeric
), f"Expected quoted column name in SQL: {sql_numeric}"


def test_adhoc_column_with_spaces_in_full_query(database: Database) -> None:
Expand Down Expand Up @@ -1743,3 +1743,116 @@ def test_orderby_adhoc_column(database: Database) -> None:
# Verify the SQL contains the expression from the adhoc column
sql = str(result.sqla_query)
assert "ORDER BY" in sql.upper()


def test_extras_where_is_parenthesized(
database: Database,
) -> None:
"""
Test that extras.where is wrapped in parentheses when composed with other
filters.

Without parentheses, an extras.where containing OR operators combined
with other filters via AND could produce unexpected evaluation order due
to SQL operator precedence (AND binds tighter than OR). Wrapping in
parentheses ensures the expression is treated as a single logical unit.
"""
from unittest.mock import patch

from sqlalchemy import text as sa_text

from superset.connectors.sqla.models import SqlaTable, TableColumn

table = SqlaTable(
database=database,
schema=None,
table_name="t",
columns=[
TableColumn(column_name="a", type="INTEGER"),
TableColumn(column_name="b", type="TEXT"),
],
)

with (
patch.object(
table,
"get_sqla_row_level_filters",
return_value=[sa_text("(b = 'restricted')")],
),
patch.object(
table,
"_process_select_expression",
return_value="1 = 1 OR 1 = 1",
),
):
sqla_query = table.get_sqla_query(
columns=["a"],
extras={"where": "1=1 OR 1=1"},
is_timeseries=False,
metrics=[],
)

with database.get_sqla_engine() as engine:
sql = str(
sqla_query.sqla_query.compile(
dialect=engine.dialect,
compile_kwargs={"literal_binds": True},
)
)

assert "(1 = 1 OR 1 = 1)" in sql, (
f"extras.where should be wrapped in parentheses. " f"Generated SQL: {sql}"
)

assert (
"b = 'restricted'" in sql
), f"Additional filters should be present in query. Generated SQL: {sql}"


def test_extras_having_is_parenthesized(
database: Database,
) -> None:
"""
Test that extras.having is wrapped in parentheses when composed with
other HAVING filters, to ensure correct evaluation order.
"""
from unittest.mock import patch

from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn

table = SqlaTable(
database=database,
schema=None,
table_name="t",
columns=[
TableColumn(column_name="a", type="INTEGER"),
TableColumn(column_name="b", type="TEXT"),
],
metrics=[
SqlMetric(metric_name="cnt", expression="COUNT(*)"),
],
)

with patch.object(
table,
"_process_select_expression",
return_value="COUNT(*) > 0 OR 1 = 1",
):
sqla_query = table.get_sqla_query(
groupby=["b"],
metrics=["cnt"],
extras={"having": "COUNT(*) > 0 OR 1=1"},
is_timeseries=False,
)

with database.get_sqla_engine() as engine:
sql = str(
sqla_query.sqla_query.compile(
dialect=engine.dialect,
compile_kwargs={"literal_binds": True},
)
)

assert "(COUNT(*) > 0 OR 1 = 1)" in sql, (
f"extras.having should be wrapped in parentheses. " f"Generated SQL: {sql}"
)
Loading