Skip to content

Commit

Permalink
fix(dashboard): Charts crashing when cross filter on adhoc column is applied (#23238)
Browse files Browse the repository at this point in the history

Co-authored-by: Ville Brofeldt <33317356+villebro@users.noreply.github.com>
  • Loading branch information
kgabryje and villebro authored Mar 4, 2023
1 parent 006f3dd commit 42980a6
Show file tree
Hide file tree
Showing 10 changed files with 143 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import {
FeatureFlag,
Filters,
FilterState,
getColumnLabel,
isFeatureEnabled,
NativeFilterType,
NO_TIME_RANGE,
Expand Down Expand Up @@ -146,8 +147,8 @@ const getAppliedColumns = (chart: any): Set<string> =>

const getRejectedColumns = (chart: any): Set<string> =>
new Set(
(chart?.queriesResponse?.[0]?.rejected_filters || []).map(
(filter: any) => filter.column,
(chart?.queriesResponse?.[0]?.rejected_filters || []).map((filter: any) =>
getColumnLabel(filter.column),
),
);

Expand Down
26 changes: 12 additions & 14 deletions superset/common/query_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

import copy
from typing import Any, Callable, cast, Dict, List, Optional, TYPE_CHECKING
from typing import Any, Callable, Dict, Optional, TYPE_CHECKING

from flask_babel import _

Expand All @@ -32,7 +32,6 @@
ExtraFiltersReasonType,
get_column_name,
get_time_filter_status,
is_adhoc_column,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -102,7 +101,6 @@ def _get_full(
datasource = _get_datasource(query_context, query_obj)
result_type = query_obj.result_type or query_context.result_type
payload = query_context.get_df_payload(query_obj, force_cached=force_cached)
applied_template_filters = payload.get("applied_template_filters", [])
df = payload["df"]
status = payload["status"]
if status != QueryStatus.FAILED:
Expand All @@ -113,23 +111,23 @@ def _get_full(
payload["result_format"] = query_context.result_format
del payload["df"]

filters = query_obj.filter
filter_columns = cast(List[str], [flt.get("col") for flt in filters])
columns = set(datasource.column_names)
applied_time_columns, rejected_time_columns = get_time_filter_status(
datasource, query_obj.applied_time_extras
)

applied_filter_columns = payload.get("applied_filter_columns", [])
rejected_filter_columns = payload.get("rejected_filter_columns", [])
del payload["applied_filter_columns"]
del payload["rejected_filter_columns"]
payload["applied_filters"] = [
{"column": get_column_name(col)}
for col in filter_columns
if is_adhoc_column(col) or col in columns or col in applied_template_filters
{"column": get_column_name(col)} for col in applied_filter_columns
] + applied_time_columns
payload["rejected_filters"] = [
{"reason": ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE, "column": col}
for col in filter_columns
if not is_adhoc_column(col)
and col not in columns
and col not in applied_template_filters
{
"reason": ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE,
"column": get_column_name(col),
}
for col in rejected_filter_columns
] + rejected_time_columns

if result_type == ChartDataResultType.RESULTS and status != QueryStatus.FAILED:
Expand Down
2 changes: 2 additions & 0 deletions superset/common/query_context_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ def get_df_payload(
"cache_timeout": self.get_cache_timeout(),
"df": cache.df,
"applied_template_filters": cache.applied_template_filters,
"applied_filter_columns": cache.applied_filter_columns,
"rejected_filter_columns": cache.rejected_filter_columns,
"annotation_data": cache.annotation_data,
"error": cache.error_message,
"is_cached": cache.is_cached,
Expand Down
15 changes: 15 additions & 0 deletions superset/common/utils/query_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from superset.extensions import cache_manager
from superset.models.helpers import QueryResult
from superset.stats_logger import BaseStatsLogger
from superset.superset_typing import Column
from superset.utils.cache import set_and_log_cache
from superset.utils.core import error_msg_from_exception, get_stacktrace

Expand All @@ -54,6 +55,8 @@ def __init__(
query: str = "",
annotation_data: Optional[Dict[str, Any]] = None,
applied_template_filters: Optional[List[str]] = None,
applied_filter_columns: Optional[List[Column]] = None,
rejected_filter_columns: Optional[List[Column]] = None,
status: Optional[str] = None,
error_message: Optional[str] = None,
is_loaded: bool = False,
Expand All @@ -66,6 +69,8 @@ def __init__(
self.query = query
self.annotation_data = {} if annotation_data is None else annotation_data
self.applied_template_filters = applied_template_filters or []
self.applied_filter_columns = applied_filter_columns or []
self.rejected_filter_columns = rejected_filter_columns or []
self.status = status
self.error_message = error_message

Expand Down Expand Up @@ -93,6 +98,8 @@ def set_query_result(
self.status = query_result.status
self.query = query_result.query
self.applied_template_filters = query_result.applied_template_filters
self.applied_filter_columns = query_result.applied_filter_columns
self.rejected_filter_columns = query_result.rejected_filter_columns
self.error_message = query_result.error_message
self.df = query_result.df
self.annotation_data = {} if annotation_data is None else annotation_data
Expand All @@ -107,6 +114,8 @@ def set_query_result(
"df": self.df,
"query": self.query,
"applied_template_filters": self.applied_template_filters,
"applied_filter_columns": self.applied_filter_columns,
"rejected_filter_columns": self.rejected_filter_columns,
"annotation_data": self.annotation_data,
}
if self.is_loaded and key and self.status != QueryStatus.FAILED:
Expand Down Expand Up @@ -150,6 +159,12 @@ def get(
query_cache.applied_template_filters = cache_value.get(
"applied_template_filters", []
)
query_cache.applied_filter_columns = cache_value.get(
"applied_filter_columns", []
)
query_cache.rejected_filter_columns = cache_value.get(
"rejected_filter_columns", []
)
query_cache.status = QueryStatus.SUCCESS
query_cache.is_loaded = True
query_cache.is_cached = cache_value is not None
Expand Down
57 changes: 48 additions & 9 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,11 @@
from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression
from superset.exceptions import (
AdvancedDataTypeResponseError,
ColumnNotFoundException,
DatasetInvalidPermissionEvaluationException,
QueryClauseValidationException,
QueryObjectValidationError,
SupersetGenericDBErrorException,
SupersetSecurityException,
)
from superset.extensions import feature_flag_manager
Expand Down Expand Up @@ -150,6 +152,8 @@

class SqlaQuery(NamedTuple):
applied_template_filters: List[str]
applied_filter_columns: List[ColumnTyping]
rejected_filter_columns: List[ColumnTyping]
cte: Optional[str]
extra_cache_keys: List[Any]
labels_expected: List[str]
Expand All @@ -159,6 +163,8 @@ class SqlaQuery(NamedTuple):

class QueryStringExtended(NamedTuple):
applied_template_filters: Optional[List[str]]
applied_filter_columns: List[ColumnTyping]
rejected_filter_columns: List[ColumnTyping]
labels_expected: List[str]
prequeries: List[str]
sql: str
Expand Down Expand Up @@ -878,6 +884,8 @@ def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExten
sql = self.mutate_query_from_config(sql)
return QueryStringExtended(
applied_template_filters=sqlaq.applied_template_filters,
applied_filter_columns=sqlaq.applied_filter_columns,
rejected_filter_columns=sqlaq.rejected_filter_columns,
labels_expected=sqlaq.labels_expected,
prequeries=sqlaq.prequeries,
sql=sql,
Expand Down Expand Up @@ -1020,13 +1028,16 @@ def adhoc_column_to_sqla(
)
is_dttm = col_in_metadata.is_temporal
else:
sqla_column = literal_column(expression)
# probe adhoc column type
tbl, _ = self.get_from_clause(template_processor)
qry = sa.select([sqla_column]).limit(1).select_from(tbl)
sql = self.database.compile_sqla_query(qry)
col_desc = get_columns_description(self.database, sql)
is_dttm = col_desc[0]["is_dttm"]
try:
sqla_column = literal_column(expression)
# probe adhoc column type
tbl, _ = self.get_from_clause(template_processor)
qry = sa.select([sqla_column]).limit(1).select_from(tbl)
sql = self.database.compile_sqla_query(qry)
col_desc = get_columns_description(self.database, sql)
is_dttm = col_desc[0]["is_dttm"]
except SupersetGenericDBErrorException as ex:
raise ColumnNotFoundException(message=str(ex)) from ex

if (
is_dttm
Expand Down Expand Up @@ -1181,6 +1192,8 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
}
columns = columns or []
groupby = groupby or []
rejected_adhoc_filters_columns: List[Union[str, ColumnTyping]] = []
applied_adhoc_filters_columns: List[Union[str, ColumnTyping]] = []
series_column_names = utils.get_column_names(series_columns or [])
# deprecated, to be removed in 2.0
if is_timeseries and timeseries_limit:
Expand Down Expand Up @@ -1439,9 +1452,14 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col:
col_obj = dttm_col
elif is_adhoc_column(flt_col):
sqla_col = self.adhoc_column_to_sqla(flt_col)
try:
sqla_col = self.adhoc_column_to_sqla(flt_col)
applied_adhoc_filters_columns.append(flt_col)
except ColumnNotFoundException:
rejected_adhoc_filters_columns.append(flt_col)
continue
else:
col_obj = columns_by_name.get(flt_col)
col_obj = columns_by_name.get(cast(str, flt_col))
filter_grain = flt.get("grain")

if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"):
Expand Down Expand Up @@ -1766,8 +1784,27 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
qry = select([col]).select_from(qry.alias("rowcount_qry"))
labels_expected = [label]

filter_columns = [flt.get("col") for flt in filter] if filter else []
rejected_filter_columns = [
col
for col in filter_columns
if col
and not is_adhoc_column(col)
and col not in self.column_names
and col not in applied_template_filters
] + rejected_adhoc_filters_columns
applied_filter_columns = [
col
for col in filter_columns
if col
and not is_adhoc_column(col)
and (col in self.column_names or col in applied_template_filters)
] + applied_adhoc_filters_columns

return SqlaQuery(
applied_template_filters=applied_template_filters,
rejected_filter_columns=rejected_filter_columns,
applied_filter_columns=applied_filter_columns,
cte=cte,
extra_cache_keys=extra_cache_keys,
labels_expected=labels_expected,
Expand Down Expand Up @@ -1906,6 +1943,8 @@ def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]:

return QueryResult(
applied_template_filters=query_str_ext.applied_template_filters,
applied_filter_columns=query_str_ext.applied_filter_columns,
rejected_filter_columns=query_str_ext.rejected_filter_columns,
status=status,
df=df,
duration=datetime.now() - qry_start_dttm,
Expand Down
4 changes: 4 additions & 0 deletions superset/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,3 +270,7 @@ class SupersetCancelQueryException(SupersetException):

class QueryNotFoundException(SupersetException):
    """Raised when a requested query cannot be located; maps to HTTP 404."""

    status = 404


class ColumnNotFoundException(SupersetException):
    """Raised when a column (e.g. an adhoc column expression) cannot be
    resolved against the datasource; maps to HTTP 404.

    Raised from ``SupersetGenericDBErrorException`` when probing an adhoc
    column's type fails, and caught in ``get_sqla_query`` to mark the
    filter column as rejected instead of crashing the chart.
    """

    status = 404
7 changes: 6 additions & 1 deletion superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
from superset.sql_parse import has_table_query, insert_rls, ParsedQuery, sanitize_clause
from superset.superset_typing import (
AdhocMetric,
Column as ColumnTyping,
FilterValue,
FilterValues,
Metric,
Expand Down Expand Up @@ -545,6 +546,8 @@ def __init__( # pylint: disable=too-many-arguments
query: str,
duration: timedelta,
applied_template_filters: Optional[List[str]] = None,
applied_filter_columns: Optional[List[ColumnTyping]] = None,
rejected_filter_columns: Optional[List[ColumnTyping]] = None,
status: str = QueryStatus.SUCCESS,
error_message: Optional[str] = None,
errors: Optional[List[Dict[str, Any]]] = None,
Expand All @@ -555,6 +558,8 @@ def __init__( # pylint: disable=too-many-arguments
self.query = query
self.duration = duration
self.applied_template_filters = applied_template_filters or []
self.applied_filter_columns = applied_filter_columns or []
self.rejected_filter_columns = rejected_filter_columns or []
self.status = status
self.error_message = error_message
self.errors = errors or []
Expand Down Expand Up @@ -1646,7 +1651,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
elif utils.is_adhoc_column(flt_col):
sqla_col = self.adhoc_column_to_sqla(flt_col) # type: ignore
else:
col_obj = columns_by_name.get(flt_col)
col_obj = columns_by_name.get(cast(str, flt_col))
filter_grain = flt.get("grain")

if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"):
Expand Down
4 changes: 2 additions & 2 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ class AdhocFilterClause(TypedDict, total=False):


class QueryObjectFilterClause(TypedDict, total=False):
col: str
col: Column
op: str # pylint: disable=invalid-name
val: Optional[FilterValues]
grain: Optional[str]
Expand Down Expand Up @@ -1089,7 +1089,7 @@ def simple_filter_to_adhoc(
"expressionType": "SIMPLE",
"comparator": filter_clause.get("val"),
"operator": filter_clause["op"],
"subject": filter_clause["col"],
"subject": cast(str, filter_clause["col"]),
}
if filter_clause.get("isExtra"):
result["isExtra"] = True
Expand Down
33 changes: 17 additions & 16 deletions superset/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def __init__(
self.status: Optional[str] = None
self.error_msg = ""
self.results: Optional[QueryResult] = None
self.applied_template_filters: List[str] = []
self.applied_filter_columns: List[Column] = []
self.rejected_filter_columns: List[Column] = []
self.errors: List[Dict[str, Any]] = []
self.force = force
self._force_cached = force_cached
Expand Down Expand Up @@ -288,7 +289,8 @@ def get_df(self, query_obj: Optional[QueryObjectDict] = None) -> pd.DataFrame:

# The datasource here can be different backend but the interface is common
self.results = self.datasource.query(query_obj)
self.applied_template_filters = self.results.applied_template_filters or []
self.applied_filter_columns = self.results.applied_filter_columns or []
self.rejected_filter_columns = self.results.rejected_filter_columns or []
self.query = self.results.query
self.status = self.results.status
self.errors = self.results.errors
Expand Down Expand Up @@ -492,25 +494,21 @@ def get_payload(self, query_obj: Optional[QueryObjectDict] = None) -> VizPayload
if "df" in payload:
del payload["df"]

filters = self.form_data.get("filters", [])
filter_columns = [flt.get("col") for flt in filters]
columns = set(self.datasource.column_names)
applied_template_filters = self.applied_template_filters or []
applied_filter_columns = self.applied_filter_columns or []
rejected_filter_columns = self.rejected_filter_columns or []
applied_time_extras = self.form_data.get("applied_time_extras", {})
applied_time_columns, rejected_time_columns = utils.get_time_filter_status(
self.datasource, applied_time_extras
)
payload["applied_filters"] = [
{"column": get_column_name(col)}
for col in filter_columns
if is_adhoc_column(col) or col in columns or col in applied_template_filters
{"column": get_column_name(col)} for col in applied_filter_columns
] + applied_time_columns
payload["rejected_filters"] = [
{"reason": ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE, "column": col}
for col in filter_columns
if not is_adhoc_column(col)
and col not in columns
and col not in applied_template_filters
{
"reason": ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE,
"column": get_column_name(col),
}
for col in rejected_filter_columns
] + rejected_time_columns
if df is not None:
payload["colnames"] = list(df.columns)
Expand All @@ -535,8 +533,11 @@ def get_df_payload( # pylint: disable=too-many-statements
try:
df = cache_value["df"]
self.query = cache_value["query"]
self.applied_template_filters = cache_value.get(
"applied_template_filters", []
self.applied_filter_columns = cache_value.get(
"applied_filter_columns", []
)
self.rejected_filter_columns = cache_value.get(
"rejected_filter_columns", []
)
self.status = QueryStatus.SUCCESS
is_loaded = True
Expand Down
Loading

0 comments on commit 42980a6

Please sign in to comment.