Skip to content

Commit

Permalink
fix: avoid escaping bind-like params containing colons (#17419)
Browse files Browse the repository at this point in the history
* fix: avoid escaping bind-like params containing colons

* fix query for mysql

* address comments
  • Loading branch information
villebro authored Nov 13, 2021
1 parent aa8040e commit ad8a7c4
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 36 deletions.
20 changes: 14 additions & 6 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, ColumnElement, literal_column, table, text
from sqlalchemy.sql import column, ColumnElement, literal_column, table
from sqlalchemy.sql.elements import ColumnClause, TextClause
from sqlalchemy.sql.expression import Label, Select, TextAsFrom
from sqlalchemy.sql.selectable import Alias, TableClause
Expand Down Expand Up @@ -103,6 +103,16 @@
VIRTUAL_TABLE_ALIAS = "virtual_table"


def text(clause: str) -> TextClause:
"""
SQLALchemy wrapper to ensure text clauses are escaped properly
:param clause: clause potentially containing colons
:return: text clause with escaped colons
"""
return sa.text(clause.replace(":", "\\:"))


class SqlaQuery(NamedTuple):
applied_template_filters: List[str]
extra_cache_keys: List[Any]
Expand Down Expand Up @@ -806,7 +816,7 @@ def get_from_clause(
raise QueryObjectValidationError(
_("Virtual dataset query must be read-only")
)
return TextAsFrom(sa.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)
return TextAsFrom(text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)

def get_rendered_sql(
self, template_processor: Optional[BaseTemplateProcessor] = None
Expand All @@ -827,8 +837,6 @@ def get_rendered_sql(
)
) from ex
sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True)
# we need to escape strings that SQLAlchemy interprets as bind parameters
sql = utils.escape_sqla_query_binds(sql)
if not sql:
raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
if len(sqlparse.split(sql)) > 1:
Expand Down Expand Up @@ -1286,7 +1294,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
msg=ex.message,
)
) from ex
where_clause_and += [sa.text("({})".format(where))]
where_clause_and += [text(f"({where})")]
having = extras.get("having")
if having:
try:
Expand All @@ -1298,7 +1306,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
msg=ex.message,
)
) from ex
having_clause_and += [sa.text("({})".format(having))]
having_clause_and += [text(f"({having})")]
if apply_fetch_values_predicate and self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate())
if granularity:
Expand Down
29 changes: 0 additions & 29 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.type_api import Variant
from sqlalchemy.types import TEXT, TypeDecorator, TypeEngine
from typing_extensions import TypedDict, TypeGuard
Expand Down Expand Up @@ -132,8 +131,6 @@

InputType = TypeVar("InputType")

BIND_PARAM_REGEX = TextClause._bind_params_regex # pylint: disable=protected-access


class LenientEnum(Enum):
"""Enums with a `get` method that convert a enum value to `Enum` if it is a
Expand Down Expand Up @@ -1771,29 +1768,3 @@ def apply_max_row_limit(limit: int, max_limit: Optional[int] = None,) -> int:
if limit != 0:
return min(max_limit, limit)
return max_limit


def escape_sqla_query_binds(sql: str) -> str:
"""
Replace strings in a query that SQLAlchemy would otherwise interpret as
bind parameters.
:param sql: unescaped query string
:return: escaped query string
>>> escape_sqla_query_binds("select ':foo'")
"select '\\\\:foo'"
>>> escape_sqla_query_binds("select 'foo'::TIMESTAMP")
"select 'foo'::TIMESTAMP"
>>> escape_sqla_query_binds("select ':foo :bar'::TIMESTAMP")
"select '\\\\:foo \\\\:bar'::TIMESTAMP"
>>> escape_sqla_query_binds("select ':foo :foo :bar'::TIMESTAMP")
"select '\\\\:foo \\\\:foo \\\\:bar'::TIMESTAMP"
"""
matches = BIND_PARAM_REGEX.finditer(sql)
processed_binds = set()
for match in matches:
bind = match.group(0)
if bind not in processed_binds:
sql = sql.replace(bind, bind.replace(":", "\\:"))
processed_binds.add(bind)
return sql
59 changes: 58 additions & 1 deletion tests/integration_tests/charts/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import unittest
from datetime import datetime
from io import BytesIO
from typing import Optional
from typing import Optional, List
from unittest import mock
from zipfile import is_zipfile, ZipFile

Expand All @@ -42,6 +42,7 @@
load_world_bank_dashboard_with_slices,
)
from tests.integration_tests.test_app import app
from superset import security_manager
from superset.charts.commands.data import ChartDataCommand
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.errors import SupersetErrorType
Expand All @@ -56,6 +57,7 @@
get_example_database,
get_example_default_schema,
get_main_database,
AdhocMetricExpressionType,
)
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType

Expand Down Expand Up @@ -2033,3 +2035,58 @@ def test_chart_data_series_limit(self):
self.assertEqual(
set(column for column in data[0].keys()), {"state", "name", "sum__num"}
)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_virtual_table_with_colons(self):
"""
Chart data API: test query with literal colon characters in query, metrics,
where clause and filters
"""
self.login(username="admin")
owner = self.get_user("admin").id
user = db.session.query(security_manager.user_model).get(owner)

table = SqlaTable(
table_name="virtual_table_1",
schema=get_example_default_schema(),
owners=[user],
database=get_example_database(),
sql="select ':foo' as foo, ':bar:' as bar, state, num from birth_names",
)
db.session.add(table)
db.session.commit()
table.fetch_metadata()

request_payload = get_query_context("birth_names")
request_payload["datasource"] = {
"type": "table",
"id": table.id,
}
request_payload["queries"][0]["columns"] = ["foo", "bar", "state"]
request_payload["queries"][0]["where"] = "':abc' != ':xyz:qwerty'"
request_payload["queries"][0]["orderby"] = None
request_payload["queries"][0]["metrics"] = [
{
"expressionType": AdhocMetricExpressionType.SQL,
"sqlExpression": "sum(case when state = ':asdf' then 0 else 1 end)",
"label": "count",
}
]
request_payload["queries"][0]["filters"] = [
{"col": "foo", "op": "!=", "val": ":qwerty:",}
]

rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
db.session.delete(table)
db.session.commit()
assert rv.status_code == 200
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
data = result["data"]
assert {col for col in data[0].keys()} == {"foo", "bar", "state", "count"}
# make sure results and query parameters are unescaped
assert {row["foo"] for row in data} == {":foo"}
assert {row["bar"] for row in data} == {":bar:"}
assert "':asdf'" in result["query"]
assert "':xyz:qwerty'" in result["query"]
assert "':qwerty:'" in result["query"]

0 comments on commit ad8a7c4

Please sign in to comment.