From 224f962e4641bbfe6eff807745262fc03a0f9b92 Mon Sep 17 00:00:00 2001 From: Kamil Gabryjelski Date: Sat, 4 Mar 2023 07:57:35 +0100 Subject: [PATCH] fix(dashboard): Charts crashing when cross filter on adhoc column is applied (#23238) Co-authored-by: Ville Brofeldt <33317356+villebro@users.noreply.github.com> (cherry picked from commit 42980a69a72a27a948f7713e5a93a4a2eaa01d2d) --- .../components/FiltersBadge/selectors.ts | 5 +- superset/common/query_actions.py | 26 ++++----- superset/common/query_context_processor.py | 2 + superset/common/utils/query_cache_manager.py | 15 +++++ superset/connectors/sqla/models.py | 57 ++++++++++++++++--- superset/exceptions.py | 4 ++ superset/models/helpers.py | 7 ++- superset/utils/core.py | 4 +- superset/viz.py | 33 +++++------ .../charts/data/api_tests.py | 34 +++++++++++ 10 files changed, 143 insertions(+), 44 deletions(-) diff --git a/superset-frontend/src/dashboard/components/FiltersBadge/selectors.ts b/superset-frontend/src/dashboard/components/FiltersBadge/selectors.ts index c0916b99607b0..582e5e5ea90a1 100644 --- a/superset-frontend/src/dashboard/components/FiltersBadge/selectors.ts +++ b/superset-frontend/src/dashboard/components/FiltersBadge/selectors.ts @@ -23,6 +23,7 @@ import { FeatureFlag, Filters, FilterState, + getColumnLabel, isFeatureEnabled, NativeFilterType, NO_TIME_RANGE, @@ -145,8 +146,8 @@ const getAppliedColumns = (chart: any): Set => const getRejectedColumns = (chart: any): Set => new Set( - (chart?.queriesResponse?.[0]?.rejected_filters || []).map( - (filter: any) => filter.column, + (chart?.queriesResponse?.[0]?.rejected_filters || []).map((filter: any) => + getColumnLabel(filter.column), ), ); diff --git a/superset/common/query_actions.py b/superset/common/query_actions.py index bfb3d368789d9..38526475b9349 100644 --- a/superset/common/query_actions.py +++ b/superset/common/query_actions.py @@ -17,7 +17,7 @@ from __future__ import annotations import copy -from typing import Any, Callable, cast, Dict, List, Optional, TYPE_CHECKING +from typing import Any, Callable, Dict, Optional, TYPE_CHECKING from flask_babel import _ @@ -32,7 +32,6 @@ ExtraFiltersReasonType, get_column_name, get_time_filter_status, - is_adhoc_column, ) if TYPE_CHECKING: @@ -102,7 +101,6 @@ def _get_full( datasource = _get_datasource(query_context, query_obj) result_type = query_obj.result_type or query_context.result_type payload = query_context.get_df_payload(query_obj, force_cached=force_cached) - applied_template_filters = payload.get("applied_template_filters", []) df = payload["df"] status = payload["status"] if status != QueryStatus.FAILED: @@ -113,23 +111,23 @@ def _get_full( payload["result_format"] = query_context.result_format del payload["df"] - filters = query_obj.filter - filter_columns = cast(List[str], [flt.get("col") for flt in filters]) - columns = set(datasource.column_names) applied_time_columns, rejected_time_columns = get_time_filter_status( datasource, query_obj.applied_time_extras ) + + applied_filter_columns = payload.get("applied_filter_columns", []) + rejected_filter_columns = payload.get("rejected_filter_columns", []) + del payload["applied_filter_columns"] + del payload["rejected_filter_columns"] payload["applied_filters"] = [ - {"column": get_column_name(col)} - for col in filter_columns - if is_adhoc_column(col) or col in columns or col in applied_template_filters + {"column": get_column_name(col)} for col in applied_filter_columns ] + applied_time_columns payload["rejected_filters"] = [ - {"reason": ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE, "column": col} - for col in filter_columns - if not is_adhoc_column(col) - and col not in columns - and col not in applied_template_filters + { + "reason": ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE, + "column": get_column_name(col), + } + for col in rejected_filter_columns ] + rejected_time_columns if result_type == ChartDataResultType.RESULTS and status != QueryStatus.FAILED: diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index 77ca69fcf6f02..703e1d71ddeaa 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -165,6 +165,8 @@ def get_df_payload( "cache_timeout": self.get_cache_timeout(), "df": cache.df, "applied_template_filters": cache.applied_template_filters, + "applied_filter_columns": cache.applied_filter_columns, + "rejected_filter_columns": cache.rejected_filter_columns, "annotation_data": cache.annotation_data, "error": cache.error_message, "is_cached": cache.is_cached, diff --git a/superset/common/utils/query_cache_manager.py b/superset/common/utils/query_cache_manager.py index 76aa5ddef32e3..6060bb76457fc 100644 --- a/superset/common/utils/query_cache_manager.py +++ b/superset/common/utils/query_cache_manager.py @@ -29,6 +29,7 @@ from superset.extensions import cache_manager from superset.models.helpers import QueryResult from superset.stats_logger import BaseStatsLogger +from superset.superset_typing import Column from superset.utils.cache import set_and_log_cache from superset.utils.core import error_msg_from_exception, get_stacktrace @@ -54,6 +55,8 @@ def __init__( query: str = "", annotation_data: Optional[Dict[str, Any]] = None, applied_template_filters: Optional[List[str]] = None, + applied_filter_columns: Optional[List[Column]] = None, + rejected_filter_columns: Optional[List[Column]] = None, status: Optional[str] = None, error_message: Optional[str] = None, is_loaded: bool = False, @@ -66,6 +69,8 @@ def __init__( self.query = query self.annotation_data = {} if annotation_data is None else annotation_data self.applied_template_filters = applied_template_filters or [] + self.applied_filter_columns = applied_filter_columns or [] + self.rejected_filter_columns = rejected_filter_columns or [] self.status = status self.error_message = error_message @@ -93,6 +98,8 @@ def set_query_result( self.status = query_result.status self.query = query_result.query self.applied_template_filters = query_result.applied_template_filters + self.applied_filter_columns = query_result.applied_filter_columns + self.rejected_filter_columns = query_result.rejected_filter_columns self.error_message = query_result.error_message self.df = query_result.df self.annotation_data = {} if annotation_data is None else annotation_data @@ -107,6 +114,8 @@ def set_query_result( "df": self.df, "query": self.query, "applied_template_filters": self.applied_template_filters, + "applied_filter_columns": self.applied_filter_columns, + "rejected_filter_columns": self.rejected_filter_columns, "annotation_data": self.annotation_data, } if self.is_loaded and key and self.status != QueryStatus.FAILED: @@ -150,6 +159,12 @@ def get( query_cache.applied_template_filters = cache_value.get( "applied_template_filters", [] ) + query_cache.applied_filter_columns = cache_value.get( + "applied_filter_columns", [] + ) + query_cache.rejected_filter_columns = cache_value.get( + "rejected_filter_columns", [] + ) query_cache.status = QueryStatus.SUCCESS query_cache.is_loaded = True query_cache.is_cached = cache_value is not None diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 8be079bde21c5..95f9121102a87 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -99,9 +99,11 @@ from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression from superset.exceptions import ( AdvancedDataTypeResponseError, + ColumnNotFoundException, DatasetInvalidPermissionEvaluationException, QueryClauseValidationException, QueryObjectValidationError, + SupersetGenericDBErrorException, SupersetSecurityException, ) from superset.extensions import feature_flag_manager @@ -150,6 +152,8 @@ class SqlaQuery(NamedTuple): applied_template_filters: List[str] + applied_filter_columns: List[ColumnTyping] + rejected_filter_columns: List[ColumnTyping] cte: Optional[str] extra_cache_keys: List[Any] labels_expected: List[str] @@ -159,6 +163,8 @@ class SqlaQuery(NamedTuple): class QueryStringExtended(NamedTuple): applied_template_filters: Optional[List[str]] + applied_filter_columns: List[ColumnTyping] + rejected_filter_columns: List[ColumnTyping] labels_expected: List[str] prequeries: List[str] sql: str @@ -882,6 +888,8 @@ def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExten sql = self.mutate_query_from_config(sql) return QueryStringExtended( applied_template_filters=sqlaq.applied_template_filters, + applied_filter_columns=sqlaq.applied_filter_columns, + rejected_filter_columns=sqlaq.rejected_filter_columns, labels_expected=sqlaq.labels_expected, prequeries=sqlaq.prequeries, sql=sql, @@ -1024,13 +1032,16 @@ def adhoc_column_to_sqla( ) is_dttm = col_in_metadata.is_temporal else: - sqla_column = literal_column(expression) - # probe adhoc column type - tbl, _ = self.get_from_clause(template_processor) - qry = sa.select([sqla_column]).limit(1).select_from(tbl) - sql = self.database.compile_sqla_query(qry) - col_desc = get_columns_description(self.database, sql) - is_dttm = col_desc[0]["is_dttm"] + try: + sqla_column = literal_column(expression) + # probe adhoc column type + tbl, _ = self.get_from_clause(template_processor) + qry = sa.select([sqla_column]).limit(1).select_from(tbl) + sql = self.database.compile_sqla_query(qry) + col_desc = get_columns_description(self.database, sql) + is_dttm = col_desc[0]["is_dttm"] + except SupersetGenericDBErrorException as ex: + raise ColumnNotFoundException(message=str(ex)) from ex if ( is_dttm @@ -1185,6 +1196,8 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma } columns = columns or [] groupby = groupby or [] + rejected_adhoc_filters_columns: List[Union[str, ColumnTyping]] = [] + applied_adhoc_filters_columns: List[Union[str, ColumnTyping]] = [] series_column_names = utils.get_column_names(series_columns or []) # deprecated, to be removed in 2.0 if is_timeseries and timeseries_limit: @@ -1443,9 +1456,14 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col: col_obj = dttm_col elif is_adhoc_column(flt_col): - sqla_col = self.adhoc_column_to_sqla(flt_col) + try: + sqla_col = self.adhoc_column_to_sqla(flt_col) + applied_adhoc_filters_columns.append(flt_col) + except ColumnNotFoundException: + rejected_adhoc_filters_columns.append(flt_col) + continue else: - col_obj = columns_by_name.get(flt_col) + col_obj = columns_by_name.get(cast(str, flt_col)) filter_grain = flt.get("grain") if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"): @@ -1770,8 +1788,27 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma qry = select([col]).select_from(qry.alias("rowcount_qry")) labels_expected = [label] + filter_columns = [flt.get("col") for flt in filter] if filter else [] + rejected_filter_columns = [ + col + for col in filter_columns + if col + and not is_adhoc_column(col) + and col not in self.column_names + and col not in applied_template_filters + ] + rejected_adhoc_filters_columns + applied_filter_columns = [ + col + for col in filter_columns + if col + and not is_adhoc_column(col) + and (col in self.column_names or col in applied_template_filters) + ] + applied_adhoc_filters_columns + return SqlaQuery( applied_template_filters=applied_template_filters, + rejected_filter_columns=rejected_filter_columns, + applied_filter_columns=applied_filter_columns, cte=cte, extra_cache_keys=extra_cache_keys, labels_expected=labels_expected, @@ -1910,6 +1947,8 @@ def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]: return QueryResult( applied_template_filters=query_str_ext.applied_template_filters, + applied_filter_columns=query_str_ext.applied_filter_columns, + rejected_filter_columns=query_str_ext.rejected_filter_columns, status=status, df=df, duration=datetime.now() - qry_start_dttm, diff --git a/superset/exceptions.py b/superset/exceptions.py index 963bf966820d5..cee15be376394 100644 --- a/superset/exceptions.py +++ b/superset/exceptions.py @@ -270,3 +270,7 @@ class SupersetCancelQueryException(SupersetException): class QueryNotFoundException(SupersetException): status = 404 + + +class ColumnNotFoundException(SupersetException): + status = 404 diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 15b7a420a079e..bce997088495b 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -80,6 +80,7 @@ from superset.sql_parse import has_table_query, insert_rls, ParsedQuery, sanitize_clause from superset.superset_typing import ( AdhocMetric, + Column as ColumnTyping, FilterValue, FilterValues, Metric, @@ -545,6 +546,8 @@ def __init__( # pylint: disable=too-many-arguments query: str, duration: timedelta, applied_template_filters: Optional[List[str]] = None, + applied_filter_columns: Optional[List[ColumnTyping]] = None, + rejected_filter_columns: Optional[List[ColumnTyping]] = None, status: str = QueryStatus.SUCCESS, error_message: Optional[str] = None, errors: Optional[List[Dict[str, Any]]] = None, @@ -555,6 +558,8 @@ def __init__( # pylint: disable=too-many-arguments self.query = query self.duration = duration self.applied_template_filters = applied_template_filters or [] + self.applied_filter_columns = applied_filter_columns or [] + self.rejected_filter_columns = rejected_filter_columns or [] self.status = status self.error_message = error_message self.errors = errors or [] @@ -1646,7 +1651,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma elif utils.is_adhoc_column(flt_col): sqla_col = self.adhoc_column_to_sqla(flt_col) # type: ignore else: - col_obj = columns_by_name.get(flt_col) + col_obj = columns_by_name.get(cast(str, flt_col)) filter_grain = flt.get("grain") if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"): diff --git a/superset/utils/core.py b/superset/utils/core.py index 6f86372f753f6..22ff3f8be4be2 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -221,7 +221,7 @@ class AdhocFilterClause(TypedDict, total=False): class QueryObjectFilterClause(TypedDict, total=False): - col: str + col: Column op: str # pylint: disable=invalid-name val: Optional[FilterValues] grain: Optional[str] @@ -1089,7 +1089,7 @@ def simple_filter_to_adhoc( "expressionType": "SIMPLE", "comparator": filter_clause.get("val"), "operator": filter_clause["op"], - "subject": filter_clause["col"], + "subject": cast(str, filter_clause["col"]), } if filter_clause.get("isExtra"): result["isExtra"] = True diff --git a/superset/viz.py b/superset/viz.py index 1f4c795325b4b..5150c42d3c551 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -154,7 +154,8 @@ def __init__( self.status: Optional[str] = None self.error_msg = "" self.results: Optional[QueryResult] = None - self.applied_template_filters: List[str] = [] + self.applied_filter_columns: List[Column] = [] + self.rejected_filter_columns: List[Column] = [] self.errors: List[Dict[str, Any]] = [] self.force = force self._force_cached = force_cached @@ -288,7 +289,8 @@ def get_df(self, query_obj: Optional[QueryObjectDict] = None) -> pd.DataFrame: # The datasource here can be different backend but the interface is common self.results = self.datasource.query(query_obj) - self.applied_template_filters = self.results.applied_template_filters or [] + self.applied_filter_columns = self.results.applied_filter_columns or [] + self.rejected_filter_columns = self.results.rejected_filter_columns or [] self.query = self.results.query self.status = self.results.status self.errors = self.results.errors @@ -492,25 +494,21 @@ def get_payload(self, query_obj: Optional[QueryObjectDict] = None) -> VizPayload if "df" in payload: del payload["df"] - filters = self.form_data.get("filters", []) - filter_columns = [flt.get("col") for flt in filters] - columns = set(self.datasource.column_names) - applied_template_filters = self.applied_template_filters or [] + applied_filter_columns = self.applied_filter_columns or [] + rejected_filter_columns = self.rejected_filter_columns or [] applied_time_extras = self.form_data.get("applied_time_extras", {}) applied_time_columns, rejected_time_columns = utils.get_time_filter_status( self.datasource, applied_time_extras ) payload["applied_filters"] = [ - {"column": get_column_name(col)} - for col in filter_columns - if is_adhoc_column(col) or col in columns or col in applied_template_filters + {"column": get_column_name(col)} for col in applied_filter_columns ] + applied_time_columns payload["rejected_filters"] = [ - {"reason": ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE, "column": col} - for col in filter_columns - if not is_adhoc_column(col) - and col not in columns - and col not in applied_template_filters + { + "reason": ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE, + "column": get_column_name(col), + } + for col in rejected_filter_columns ] + rejected_time_columns if df is not None: payload["colnames"] = list(df.columns) @@ -535,8 +533,11 @@ def get_df_payload( # pylint: disable=too-many-statements try: df = cache_value["df"] self.query = cache_value["query"] - self.applied_template_filters = cache_value.get( - "applied_template_filters", [] + self.applied_filter_columns = cache_value.get( + "applied_filter_columns", [] + ) + self.rejected_filter_columns = cache_value.get( + "rejected_filter_columns", [] ) self.status = QueryStatus.SUCCESS is_loaded = True diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index 66151362ff1d4..83fb7281fbc74 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -56,6 +56,7 @@ AnnotationType, get_example_default_schema, AdhocMetricExpressionType, + ExtraFiltersReasonType, ) from superset.utils.database import get_example_database, get_main_database from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType @@ -73,6 +74,12 @@ "when gender = 'girl' then 'female' else 'other' end", } +INCOMPATIBLE_ADHOC_COLUMN_FIXTURE: AdhocColumn = { + "hasCustomLabel": True, + "label": "exciting_or_boring", + "sqlExpression": "case when genre = 'Action' then 'Exciting' else 'Boring' end", +} + class BaseTestChartDataApi(SupersetTestCase): query_context_payload_template = None @@ -1059,6 +1066,33 @@ def test_chart_data_with_adhoc_column(self): assert unique_genders == {"male", "female"} assert result["applied_filters"] == [{"column": "male_or_female"}] + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_chart_data_with_incompatible_adhoc_column(self): + """ + Chart data API: Test query with adhoc column that fails to run on this dataset + """ + self.login(username="admin") + request_payload = get_query_context("birth_names") + request_payload["queries"][0]["columns"] = [ADHOC_COLUMN_FIXTURE] + request_payload["queries"][0]["filters"] = [ + {"col": INCOMPATIBLE_ADHOC_COLUMN_FIXTURE, "op": "IN", "val": ["Exciting"]}, + {"col": ADHOC_COLUMN_FIXTURE, "op": "IN", "val": ["male", "female"]}, + ] + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + response_payload = json.loads(rv.data.decode("utf-8")) + result = response_payload["result"][0] + data = result["data"] + assert {column for column in data[0].keys()} == {"male_or_female", "sum__num"} + unique_genders = {row["male_or_female"] for row in data} + assert unique_genders == {"male", "female"} + assert result["applied_filters"] == [{"column": "male_or_female"}] + assert result["rejected_filters"] == [ + { + "column": "exciting_or_boring", + "reason": ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE, + } + ] + @pytest.fixture() def physical_query_context(physical_dataset) -> Dict[str, Any]: