Skip to content

Commit

Permalink
fix: adhoc metrics (#30202)
Browse files (browse the repository at this point in the history)
  • Loading branch information
betodealmeida authored and mistercrunch committed Oct 28, 2024
1 parent 7dda9bd commit c2a407e
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 45 deletions.
2 changes: 2 additions & 0 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1533,6 +1533,7 @@ def adhoc_metric_to_sqla(
expression = self._process_sql_expression(
expression=metric["sqlExpression"],
database_id=self.database_id,
engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
Expand Down Expand Up @@ -1566,6 +1567,7 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals
expression = self._process_sql_expression(
expression=col["sqlExpression"],
database_id=self.database_id,
engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
Expand Down
21 changes: 19 additions & 2 deletions superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
ColumnNotFoundException,
QueryClauseValidationException,
QueryObjectValidationError,
SupersetParseError,
SupersetSecurityException,
)
from superset.extensions import feature_flag_manager
Expand Down Expand Up @@ -112,6 +113,7 @@
def validate_adhoc_subquery(
sql: str,
database_id: int,
engine: str,
default_schema: str,
) -> str:
"""
Expand All @@ -126,7 +128,12 @@ def validate_adhoc_subquery(
"""
statements = []
for statement in sqlparse.parse(sql):
if has_table_query(statement):
try:
has_table = has_table_query(str(statement), engine)
except SupersetParseError:
has_table = True

if has_table:
if not is_feature_enabled("ALLOW_ADHOC_SUBQUERY"):
raise SupersetSecurityException(
SupersetError(
Expand All @@ -135,7 +142,9 @@ def validate_adhoc_subquery(
level=ErrorLevel.ERROR,
)
)
# TODO (betodealmeida): reimplement with sqlglot
statement = insert_rls_in_predicate(statement, database_id, default_schema)

statements.append(statement)

return ";\n".join(str(statement) for statement in statements)
Expand Down Expand Up @@ -810,10 +819,11 @@ def get_sqla_row_level_filters(
# for datasources of type query
return []

def _process_sql_expression(
def _process_sql_expression( # pylint: disable=too-many-arguments
self,
expression: Optional[str],
database_id: int,
engine: str,
schema: str,
template_processor: Optional[BaseTemplateProcessor],
) -> Optional[str]:
Expand All @@ -823,6 +833,7 @@ def _process_sql_expression(
expression = validate_adhoc_subquery(
expression,
database_id,
engine,
schema,
)
try:
Expand Down Expand Up @@ -1108,6 +1119,7 @@ def adhoc_metric_to_sqla(
expression = self._process_sql_expression(
expression=metric["sqlExpression"],
database_id=self.database_id,
engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
Expand Down Expand Up @@ -1551,6 +1563,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
col["sqlExpression"] = self._process_sql_expression(
expression=col["sqlExpression"],
database_id=self.database_id,
engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
Expand Down Expand Up @@ -1613,6 +1626,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
selected = validate_adhoc_subquery(
selected,
self.database_id,
self.database.backend,
self.schema,
)
outer = literal_column(f"({selected})")
Expand All @@ -1639,6 +1653,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
selected = validate_adhoc_subquery(
_sql,
self.database_id,
self.database.backend,
self.schema,
)

Expand Down Expand Up @@ -1915,6 +1930,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
where = self._process_sql_expression(
expression=where,
database_id=self.database_id,
engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
Expand All @@ -1933,6 +1949,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
having = self._process_sql_expression(
expression=having,
database_id=self.database_id,
engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
Expand Down
1 change: 1 addition & 0 deletions superset/models/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ def adhoc_column_to_sqla(
expression = self._process_sql_expression(
expression=col["sqlExpression"],
database_id=self.database_id,
engine=self.database.backend,
schema=self.schema,
template_processor=template_processor,
)
Expand Down
40 changes: 13 additions & 27 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
extract_tables_from_statement,
SQLGLOT_DIALECTS,
SQLScript,
SQLStatement,
Table,
)
from superset.utils.backports import StrEnum
Expand Down Expand Up @@ -570,46 +571,31 @@ class InsertRLSState(StrEnum):
FOUND_TABLE = "FOUND_TABLE"


def has_table_query(token_list: TokenList) -> bool:
def has_table_query(expression: str, engine: str) -> bool:
"""
Return if a statement has a query reading from a table.
>>> has_table_query(sqlparse.parse("COUNT(*)")[0])
>>> has_table_query("COUNT(*)", "postgresql")
False
>>> has_table_query(sqlparse.parse("SELECT * FROM table")[0])
>>> has_table_query("SELECT * FROM table", "postgresql")
True
Note that queries reading from constant values return false:
>>> has_table_query(sqlparse.parse("SELECT * FROM (SELECT 1)")[0])
>>> has_table_query("SELECT * FROM (SELECT 1)", "postgresql")
False
"""
state = InsertRLSState.SCANNING
for token in token_list.tokens:
# Ignore comments
if isinstance(token, sqlparse.sql.Comment):
continue

# Recurse into child token list
if isinstance(token, TokenList) and has_table_query(token):
return True

# Found a source keyword (FROM/JOIN)
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
state = InsertRLSState.SEEN_SOURCE

# Found identifier/keyword after FROM/JOIN
elif state == InsertRLSState.SEEN_SOURCE and (
isinstance(token, sqlparse.sql.Identifier) or token.ttype == Keyword
):
return True
# Remove trailing semicolon.
expression = expression.strip().rstrip(";")

# Found nothing, leaving source
elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
state = InsertRLSState.SCANNING
# Wrap the expression in parentheses if it's not already.
if not expression.startswith("("):
expression = f"({expression})"

return False
sql = f"SELECT {expression}"
statement = SQLStatement(sql, engine)
return any(statement.tables)


def add_table_name(rls: TokenList, table: str) -> None:
Expand Down
4 changes: 4 additions & 0 deletions tests/integration_tests/datasource_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
get_main_database,
)
from tests.integration_tests.base_tests import db_insert_temp_object, SupersetTestCase
from tests.integration_tests.conftest import with_feature_flags
from tests.integration_tests.constants import ADMIN_USERNAME
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices, # noqa: F401
Expand Down Expand Up @@ -585,6 +586,7 @@ def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_data
assert "INCORRECT SQL" in rv.json.get("error")


@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_dataset):
uri = (
f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
Expand Down Expand Up @@ -649,6 +651,7 @@ def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
assert rv.json["result"]["rowcount"] == 0


@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
def test_get_samples_with_time_filter(test_client, login_as_admin, physical_dataset):
uri = (
f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
Expand All @@ -669,6 +672,7 @@ def test_get_samples_with_time_filter(test_client, login_as_admin, physical_data
assert rv.json["result"]["total_count"] == 2


@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
def test_get_samples_with_multiple_filters(
test_client, login_as_admin, physical_dataset
):
Expand Down
7 changes: 6 additions & 1 deletion tests/integration_tests/query_context_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@
)
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.conftest import only_postgresql, only_sqlite
from tests.integration_tests.conftest import (
only_postgresql,
only_sqlite,
with_feature_flags,
)
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices, # noqa: F401
load_birth_names_data, # noqa: F401
Expand Down Expand Up @@ -858,6 +862,7 @@ def test_non_time_column_with_time_grain(app_context, physical_dataset):
assert df["COL2 ALIAS"][0] == "a"


@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
def test_special_chars_in_column_name(app_context, physical_dataset):
qc = QueryContextFactory().create(
datasource={
Expand Down
50 changes: 35 additions & 15 deletions tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1286,46 +1286,66 @@ def test_sqlparse_issue_652():


@pytest.mark.parametrize(
"sql,expected",
("engine", "sql", "expected"),
[
("SELECT * FROM table", True),
("SELECT a FROM (SELECT 1 AS a) JOIN (SELECT * FROM table)", True),
("(SELECT COUNT(DISTINCT name) AS foo FROM birth_names)", True),
("COUNT(*)", False),
("SELECT a FROM (SELECT 1 AS a)", False),
("SELECT a FROM (SELECT 1 AS a) JOIN table", True),
("SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar", False),
("SELECT * FROM other_table", True),
("extract(HOUR from from_unixtime(hour_ts)", False),
("(SELECT * FROM table)", True),
("(SELECT COUNT(DISTINCT name) from birth_names)", True),
("postgresql", "extract(HOUR from from_unixtime(hour_ts))", False),
("postgresql", "SELECT * FROM table", True),
("postgresql", "(SELECT * FROM table)", True),
(
"postgresql",
"SELECT a FROM (SELECT 1 AS a) JOIN (SELECT * FROM table)",
True,
),
(
"postgresql",
"(SELECT COUNT(DISTINCT name) AS foo FROM birth_names)",
True,
),
("postgresql", "COUNT(*)", False),
("postgresql", "SELECT a FROM (SELECT 1 AS a)", False),
("postgresql", "SELECT a FROM (SELECT 1 AS a) JOIN table", True),
(
"postgresql",
"SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar",
False,
),
("postgresql", "SELECT * FROM other_table", True),
("postgresql", "(SELECT COUNT(DISTINCT name) from birth_names)", True),
(
"postgresql",
"(SELECT table_name FROM information_schema.tables WHERE table_name LIKE '%user%' LIMIT 1)",
True,
),
(
"postgresql",
"(SELECT table_name FROM /**/ information_schema.tables WHERE table_name LIKE '%user%' LIMIT 1)",
True,
),
(
"postgresql",
"SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;",
True,
),
(
"postgresql",
"SELECT * FROM (SELECT * FROM forbidden_table) forbidden_table",
True,
),
(
"postgresql",
"((select users.id from (select 'majorie' as a) b, users where b.a = users.name and users.name in ('majorie') limit 1) like 'U%')",
True,
),
],
)
def test_has_table_query(sql: str, expected: bool) -> None:
def test_has_table_query(engine: str, sql: str, expected: bool) -> None:
"""
Test if a given statement queries a table.
This is used to prevent ad-hoc metrics from querying unauthorized tables, bypassing
row-level security.
"""
statement = sqlparse.parse(sql)[0]
assert has_table_query(statement) == expected
assert has_table_query(sql, engine) == expected


@pytest.mark.parametrize(
Expand Down

0 comments on commit c2a407e

Please sign in to comment.