Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(mssql): support cte in virtual tables #18567

Merged
merged 8 commits into from
Feb 10, 2022
47 changes: 35 additions & 12 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
get_physical_table_metadata,
get_virtual_table_metadata,
)
from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression
from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampExpression
from superset.exceptions import QueryObjectValidationError
from superset.jinja_context import (
BaseTemplateProcessor,
Expand Down Expand Up @@ -107,6 +107,7 @@

class SqlaQuery(NamedTuple):
applied_template_filters: List[str]
cte: Optional[str]
extra_cache_keys: List[Any]
labels_expected: List[str]
prequeries: List[str]
Expand Down Expand Up @@ -562,6 +563,19 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
def __repr__(self) -> str:
return self.name

@staticmethod
def _apply_cte(sql: str, cte: Optional[str]) -> str:
"""
Append a CTE before the SELECT statement if defined

:param sql: SELECT statement
:param cte: CTE statement
:return:
"""
if cte:
sql = f"{cte}\n{sql}"
return sql

@property
def db_engine_spec(self) -> Type[BaseEngineSpec]:
return self.database.db_engine_spec
Expand Down Expand Up @@ -743,20 +757,18 @@ def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
cols = {col.column_name: col for col in self.columns}
target_col = cols[column_name]
tp = self.get_template_processor()
tbl, cte = self.get_from_clause(tp)

qry = (
select([target_col.get_sqla_col()])
.select_from(self.get_from_clause(tp))
.distinct()
)
qry = select([target_col.get_sqla_col()]).select_from(tbl).distinct()
if limit:
qry = qry.limit(limit)

if self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate())

engine = self.database.get_sqla_engine()
sql = "{}".format(qry.compile(engine, compile_kwargs={"literal_binds": True}))
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
sql = self._apply_cte(sql, cte)
sql = self.mutate_query_from_config(sql)

df = pd.read_sql_query(sql=sql, con=engine)
Expand All @@ -778,6 +790,7 @@ def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended:
sqlaq = self.get_sqla_query(**query_obj)
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
sql = self._apply_cte(sql, sqlaq.cte)
sql = sqlparse.format(sql, reindent=True)
sql = self.mutate_query_from_config(sql)
return QueryStringExtended(
Expand All @@ -800,13 +813,14 @@ def get_sqla_table(self) -> TableClause:

def get_from_clause(
self, template_processor: Optional[BaseTemplateProcessor] = None
) -> Union[TableClause, Alias]:
) -> Tuple[Union[TableClause, Alias], Optional[str]]:
"""
Return where to select the columns and metrics from. Either a physical table
or a virtual table with it's own subquery.
or a virtual table with it's own subquery. If the FROM is referencing a
CTE, the CTE is returned as the second value in the return tuple.
"""
if not self.is_virtual:
return self.get_sqla_table()
return self.get_sqla_table(), None

from_sql = self.get_rendered_sql(template_processor)
parsed_query = ParsedQuery(from_sql)
Expand All @@ -817,7 +831,15 @@ def get_from_clause(
raise QueryObjectValidationError(
_("Virtual dataset query must be read-only")
)
return TextAsFrom(self.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)

cte = self.db_engine_spec.get_cte_query(from_sql)
from_clause = (
table(CTE_ALIAS)
if cte
else TextAsFrom(self.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)
)

return from_clause, cte

def get_rendered_sql(
self, template_processor: Optional[BaseTemplateProcessor] = None
Expand Down Expand Up @@ -1224,7 +1246,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma

qry = sa.select(select_exprs)

tbl = self.get_from_clause(template_processor)
tbl, cte = self.get_from_clause(template_processor)

if groupby_all_columns:
qry = qry.group_by(*groupby_all_columns.values())
Expand Down Expand Up @@ -1491,6 +1513,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma

return SqlaQuery(
applied_template_filters=applied_template_filters,
cte=cte,
extra_cache_keys=extra_cache_keys,
labels_expected=labels_expected,
sqla_query=qry,
Expand Down
34 changes: 34 additions & 0 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause
from sqlalchemy.types import TypeEngine
from sqlparse.tokens import CTE
from typing_extensions import TypedDict

from superset import security_manager, sql_parse
Expand All @@ -80,6 +81,9 @@
logger = logging.getLogger()


CTE_ALIAS = "__cte"


class TimeGrain(NamedTuple):
name: str # TODO: redundant field, remove
label: str
Expand Down Expand Up @@ -292,6 +296,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# But for backward compatibility, False by default
allows_hidden_cc_in_orderby = False

# Whether allow CTE as subquery or regular CTE
# If True, then it will allow in subquery ,
# if False it will allow as regular CTE
allows_cte_in_subquery = True

force_column_alias_quotes = False
arraysize = 0
max_column_name_length = 0
Expand Down Expand Up @@ -663,6 +672,31 @@ def set_or_update_query_limit(cls, sql: str, limit: int) -> str:
parsed_query = sql_parse.ParsedQuery(sql)
return parsed_query.set_or_update_query_limit(limit)

@classmethod
def get_cte_query(cls, sql: str) -> Optional[str]:
"""
Convert the input CTE based SQL to the SQL for virtual table conversion

:param sql: SQL query
:return: CTE with the main select query aliased as `__cte`

"""
if not cls.allows_cte_in_subquery:
stmt = sqlparse.parse(sql)[0]

# The first meaningful token for CTE will be with WITH
idx, token = stmt.token_next(-1, skip_ws=True, skip_cm=True)
if not (token and token.ttype == CTE):
return None
idx, token = stmt.token_next(idx)
idx = stmt.token_index(token) + 1

# extract rest of the SQLs after CTE
remainder = "".join(str(token) for token in stmt.tokens[idx:]).strip()
return f"WITH {token.value},\n{CTE_ALIAS} AS (\n{remainder}\n)"

return None

@classmethod
def df_to_sql(
cls,
Expand Down
1 change: 1 addition & 0 deletions superset/db_engine_specs/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class MssqlEngineSpec(BaseEngineSpec):
engine_name = "Microsoft SQL Server"
limit_method = LimitMethod.WRAP_SQL
max_column_name_length = 128
allows_cte_in_subquery = False

_time_grain_expressions = {
None: "{col}",
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,7 +925,7 @@ def test_comments_in_sqlatable_query(self):
sql=commented_query,
database=get_example_database(),
)
rendered_query = str(table.get_from_clause())
rendered_query = str(table.get_from_clause()[0])
self.assertEqual(clean_query, rendered_query)

def test_slice_payload_no_datasource(self):
Expand Down
43 changes: 43 additions & 0 deletions tests/unit_tests/db_engine_specs/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access

from textwrap import dedent

import pytest
from flask.ctx import AppContext
from sqlalchemy.types import TypeEngine


def test_get_text_clause_with_colon(app_context: AppContext) -> None:
Expand Down Expand Up @@ -56,3 +60,42 @@ def test_parse_sql_multi_statement(app_context: AppContext) -> None:
"SELECT foo FROM tbl1",
"SELECT bar FROM tbl2",
]


@pytest.mark.parametrize(
"original,expected",
[
(
dedent(
"""
with currency as
(
select 'INR' as cur
)
select * from currency
"""
),
None,
),
("SELECT 1 as cnt", None,),
(
dedent(
"""
select 'INR' as cur
union
select 'AUD' as cur
union
select 'USD' as cur
"""
),
None,
),
],
)
def test_cte_query_parsing(
app_context: AppContext, original: TypeEngine, expected: str
) -> None:
from superset.db_engine_specs.base import BaseEngineSpec

actual = BaseEngineSpec.get_cte_query(original)
assert actual == expected
51 changes: 51 additions & 0 deletions tests/unit_tests/db_engine_specs/test_mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,57 @@ def test_column_datatype_to_string(
assert actual == expected


@pytest.mark.parametrize(
"original,expected",
[
(
dedent(
"""
with currency as (
select 'INR' as cur
),
currency_2 as (
select 'EUR' as cur
)
select * from currency union all select * from currency_2
"""
),
dedent(
"""WITH currency as (
select 'INR' as cur
),
currency_2 as (
select 'EUR' as cur
),
__cte AS (
select * from currency union all select * from currency_2
)"""
),
),
("SELECT 1 as cnt", None,),
(
dedent(
"""
select 'INR' as cur
union
select 'AUD' as cur
union
select 'USD' as cur
"""
),
None,
),
],
)
def test_cte_query_parsing(
app_context: AppContext, original: TypeEngine, expected: str
) -> None:
from superset.db_engine_specs.mssql import MssqlEngineSpec

actual = MssqlEngineSpec.get_cte_query(original)
assert actual == expected


def test_extract_errors(app_context: AppContext) -> None:
"""
Test that custom error messages are extracted correctly.
Expand Down