From f3ca3bd298ea012b402d949ee8eb8e178fd7b61e Mon Sep 17 00:00:00 2001
From: ofekisr <35701650+ofekisr@users.noreply.github.com>
Date: Thu, 18 Nov 2021 19:53:56 +0200
Subject: [PATCH] refactor(QueryObject): decouple from superset (#17479)

* refactor: queryObject - decouple from superset

* refactor: queryObject - decouple from superset
---
 superset/common/query_context.py              |   9 +-
 superset/common/query_object.py               |  81 +-----
 superset/common/query_object_factory.py       | 134 +++++++++
 tests/common/__init__.py                      |  16 ++
 tests/common/query_context_generator.py       | 259 ++++++++++++++++++
 .../charts/data/api_tests.py                  |   7 +-
 .../integration_tests/charts/schema_tests.py  |  27 +-
 tests/integration_tests/core_tests.py         |   1 -
 .../fixtures/query_context.py                 | 225 +--------------
 tests/unit_tests/common/__init__.py           |  16 ++
 .../common/test_query_object_factory.py       | 116 ++++++++
 11 files changed, 562 insertions(+), 329 deletions(-)
 create mode 100644 superset/common/query_object_factory.py
 create mode 100644 tests/common/__init__.py
 create mode 100644 tests/common/query_context_generator.py
 create mode 100644 tests/unit_tests/common/__init__.py
 create mode 100644 tests/unit_tests/common/test_query_object_factory.py

diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index 6dc28708eaa63..8f17db985ed77 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -32,7 +32,8 @@
 from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
 from superset.common.db_query_status import QueryStatus
 from superset.common.query_actions import get_query_results
-from superset.common.query_object import QueryObject, QueryObjectFactory
+from superset.common.query_object import QueryObject
+from superset.common.query_object_factory import QueryObjectFactory
 from superset.common.utils import QueryCacheManager
 from superset.connectors.base.models import BaseDatasource
 from superset.connectors.connector_registry import ConnectorRegistry
@@ -69,6 +70,10 @@ class CachedTimeOffset(TypedDict):
     cache_keys: List[Optional[str]]
 
 
+def create_query_object_factory() -> QueryObjectFactory:
+    return QueryObjectFactory(config, ConnectorRegistry(), db.session)
+
+
 class QueryContext:
     """
     The query context contains the query object and additional fields necessary
@@ -102,7 +107,7 @@ def __init__(
         )
         self.result_type = result_type or ChartDataResultType.FULL
         self.result_format = result_format or ChartDataResultFormat.JSON
-        query_object_factory = QueryObjectFactory()
+        query_object_factory = create_query_object_factory()
         self.queries = [
             query_object_factory.create(self.result_type, **query_obj)
             for query_obj in queries
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index e7480370260d3..ff1ad710ee40e 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -14,25 +14,21 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name, no-self-use
+# pylint: disable=invalid-name
 from __future__ import annotations
 
 import logging
 from datetime import datetime, timedelta
-from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TYPE_CHECKING
+from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING
 
 from flask_babel import gettext as _
 from pandas import DataFrame
 
-from superset import app, db
 from superset.common.chart_data import ChartDataResultType
-from superset.connectors.connector_registry import ConnectorRegistry
 from superset.exceptions import QueryObjectValidationError
 from superset.typing import Column, Metric, OrderBy
 from superset.utils import pandas_postprocessing
 from superset.utils.core import (
-    apply_max_row_limit,
-    DatasourceDict,
     DTTM_ALIAS,
     find_duplicates,
     get_column_names,
@@ -41,15 +37,12 @@
     json_int_dttm_ser,
     QueryObjectFilterClause,
 )
-from superset.utils.date_parser import get_since_until, parse_human_timedelta
+from superset.utils.date_parser import parse_human_timedelta
 from superset.utils.hashing import md5_sha_from_dict
-from superset.views.utils import get_time_range_endpoints
 
 if TYPE_CHECKING:
     from superset.connectors.base.models import BaseDatasource
 
-
-config = app.config
 logger = logging.getLogger(__name__)
 
 # TODO: Type Metrics dictionary with TypedDict when it becomes a vanilla python type
@@ -404,71 +397,3 @@ def exec_post_processing(self, df: DataFrame) -> DataFrame:
             options = post_process.get("options", {})
             df = getattr(pandas_postprocessing, operation)(df, **options)
         return df
-
-
-class QueryObjectFactory:  # pylint: disable=too-few-public-methods
-    def create(  # pylint: disable=too-many-arguments
-        self,
-        parent_result_type: ChartDataResultType,
-        datasource: Optional[DatasourceDict] = None,
-        extras: Optional[Dict[str, Any]] = None,
-        row_limit: Optional[int] = None,
-        time_range: Optional[str] = None,
-        time_shift: Optional[str] = None,
-        **kwargs: Any,
-    ) -> QueryObject:
-        datasource_model_instance = None
-        if datasource:
-            datasource_model_instance = self._convert_to_model(datasource)
-        processed_extras = self._process_extras(extras)
-        result_type = kwargs.setdefault("result_type", parent_result_type)
-        row_limit = self._process_row_limit(row_limit, result_type)
-        from_dttm, to_dttm = self._get_dttms(time_range, time_shift, processed_extras)
-        kwargs["from_dttm"] = from_dttm
-        kwargs["to_dttm"] = to_dttm
-        return QueryObject(
-            datasource=datasource_model_instance,
-            extras=extras,
-            row_limit=row_limit,
-            time_range=time_range,
-            time_shift=time_shift,
-            **kwargs,
-        )
-
-    def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
-        return ConnectorRegistry.get_datasource(
-            str(datasource["type"]), int(datasource["id"]), db.session
-        )
-
-    def _process_extras(self, extras: Optional[Dict[str, Any]]) -> Dict[str, Any]:
-        extras = extras or {}
-        if config["SIP_15_ENABLED"]:
-            extras["time_range_endpoints"] = get_time_range_endpoints(form_data=extras)
-        return extras
-
-    def _process_row_limit(
-        self, row_limit: Optional[int], result_type: ChartDataResultType
-    ) -> int:
-        default_row_limit = (
-            config["SAMPLES_ROW_LIMIT"]
-            if result_type == ChartDataResultType.SAMPLES
-            else config["ROW_LIMIT"]
-        )
-        return apply_max_row_limit(row_limit or default_row_limit)
-
-    def _get_dttms(
-        self,
-        time_range: Optional[str],
-        time_shift: Optional[str],
-        extras: Dict[str, Any],
-    ) -> Tuple[Optional[datetime], Optional[datetime]]:
-        return get_since_until(
-            relative_start=extras.get(
-                "relative_start", config["DEFAULT_RELATIVE_START_TIME"]
-            ),
-            relative_end=extras.get(
-                "relative_end", config["DEFAULT_RELATIVE_END_TIME"]
-            ),
-            time_range=time_range,
-            time_shift=time_shift,
-        )
diff --git a/superset/common/query_object_factory.py b/superset/common/query_object_factory.py
new file mode 100644
index 0000000000000..d19be733739da
--- /dev/null
+++ b/superset/common/query_object_factory.py
@@ -0,0 +1,134 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import date, datetime
+from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
+
+from superset.common.chart_data import ChartDataResultType
+from superset.common.query_object import QueryObject
+from superset.utils.core import apply_max_row_limit, DatasourceDict, TimeRangeEndpoint
+from superset.utils.date_parser import get_since_until
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm import sessionmaker
+
+    from superset import ConnectorRegistry
+    from superset.connectors.base.models import BaseDatasource
+
+
+class QueryObjectFactory:  # pylint: disable=too-few-public-methods
+    _config: Dict[str, Any]
+    _connector_registry: ConnectorRegistry
+    _session_maker: sessionmaker
+
+    def __init__(
+        self,
+        app_configurations: Dict[str, Any],
+        connector_registry: ConnectorRegistry,
+        session_maker: sessionmaker,
+    ):
+        self._config = app_configurations
+        self._connector_registry = connector_registry
+        self._session_maker = session_maker
+
+    def create(  # pylint: disable=too-many-arguments
+        self,
+        parent_result_type: ChartDataResultType,
+        datasource: Optional[DatasourceDict] = None,
+        extras: Optional[Dict[str, Any]] = None,
+        row_limit: Optional[int] = None,
+        time_range: Optional[str] = None,
+        time_shift: Optional[str] = None,
+        **kwargs: Any,
+    ) -> QueryObject:
+        datasource_model_instance = None
+        if datasource:
+            datasource_model_instance = self._convert_to_model(datasource)
+        processed_extras = self._process_extras(extras)
+        result_type = kwargs.setdefault("result_type", parent_result_type)
+        row_limit = self._process_row_limit(row_limit, result_type)
+        from_dttm, to_dttm = self._get_dttms(time_range, time_shift, processed_extras)
+        kwargs["from_dttm"] = from_dttm
+        kwargs["to_dttm"] = to_dttm
+        return QueryObject(
+            datasource=datasource_model_instance,
+            extras=extras,
+            row_limit=row_limit,
+            time_range=time_range,
+            time_shift=time_shift,
+            **kwargs,
+        )
+
+    def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
+        return self._connector_registry.get_datasource(
+            str(datasource["type"]), int(datasource["id"]), self._session_maker()
+        )
+
+    def _process_extras(self, extras: Optional[Dict[str, Any]]) -> Dict[str, Any]:
+        extras = extras or {}
+        if self._config["SIP_15_ENABLED"]:
+            extras["time_range_endpoints"] = self._determine_time_range_endpoints(
+                extras.get("time_range_endpoints")
+            )
+        return extras
+
+    def _process_row_limit(
+        self, row_limit: Optional[int], result_type: ChartDataResultType
+    ) -> int:
+        default_row_limit = (
+            self._config["SAMPLES_ROW_LIMIT"]
+            if result_type == ChartDataResultType.SAMPLES
+            else self._config["ROW_LIMIT"]
+        )
+        return apply_max_row_limit(row_limit or default_row_limit)
+
+    def _get_dttms(
+        self,
+        time_range: Optional[str],
+        time_shift: Optional[str],
+        extras: Dict[str, Any],
+    ) -> Tuple[Optional[datetime], Optional[datetime]]:
+        return get_since_until(
+            relative_start=extras.get(
+                "relative_start", self._config["DEFAULT_RELATIVE_START_TIME"]
+            ),
+            relative_end=extras.get(
+                "relative_end", self._config["DEFAULT_RELATIVE_END_TIME"]
+            ),
+            time_range=time_range,
+            time_shift=time_shift,
+        )
+
+    # A light version of views.utils.get_time_range_endpoints; importing
+    # views.utils requires an application context.
+    # TODO: move this and the views.utils version into the utils package.
+
+    def _determine_time_range_endpoints(
+        self, raw_endpoints: Optional[Tuple[str, str]] = None,
+    ) -> Optional[Tuple[TimeRangeEndpoint, TimeRangeEndpoint]]:
+        if (
+            self._config["SIP_15_GRACE_PERIOD_END"]
+            and date.today() >= self._config["SIP_15_GRACE_PERIOD_END"]
+        ):
+            return TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.EXCLUSIVE
+
+        if raw_endpoints:
+            start, end = raw_endpoints
+            return TimeRangeEndpoint(start), TimeRangeEndpoint(end)
+
+        return TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.EXCLUSIVE
diff --git a/tests/common/__init__.py b/tests/common/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/tests/common/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/common/query_context_generator.py b/tests/common/query_context_generator.py
new file mode 100644
index 0000000000000..69bafc175d9ba
--- /dev/null
+++ b/tests/common/query_context_generator.py
@@ -0,0 +1,259 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import copy
+import dataclasses
+from typing import Any, Dict, List
+
+from superset.common.chart_data import ChartDataResultType
+from superset.utils.core import AnnotationType, DTTM_ALIAS, TimeRangeEndpoint
+
+query_birth_names = {
+    "extras": {
+        "where": "",
+        "time_range_endpoints": (
+            TimeRangeEndpoint.INCLUSIVE,
+            TimeRangeEndpoint.EXCLUSIVE,
+        ),
+        "time_grain_sqla": "P1D",
+    },
+    "columns": ["name"],
+    "metrics": [{"label": "sum__num"}],
+    "orderby": [("sum__num", False)],
+    "row_limit": 100,
+    "granularity": "ds",
+    "time_range": "100 years ago : now",
+    "timeseries_limit": 0,
+    "timeseries_limit_metric": None,
+    "order_desc": True,
+    "filters": [
+        {"col": "gender", "op": "==", "val": "boy"},
+        {"col": "num", "op": "IS NOT NULL"},
+        {"col": "name", "op": "NOT IN", "val": ["", '"abc"']},
+    ],
+    "having": "",
+    "having_filters": [],
+    "where": "",
+}
+
+QUERY_OBJECTS: Dict[str, Dict[str, object]] = {
+    "birth_names": query_birth_names,
+    # `:suffix` are overrides only
+    "birth_names:include_time": {"groupby": [DTTM_ALIAS, "name"],},
+    "birth_names:orderby_dup_alias": {
+        "metrics": [
+            {
+                "expressionType": "SIMPLE",
+                "column": {"column_name": "num_girls", "type": "BIGINT(20)"},
+                "aggregate": "SUM",
+                "label": "num_girls",
+            },
+            {
+                "expressionType": "SIMPLE",
+                "column": {"column_name": "num_boys", "type": "BIGINT(20)"},
+                "aggregate": "SUM",
+                "label": "num_boys",
+            },
+        ],
+        "orderby": [
+            [
+                {
+                    "expressionType": "SIMPLE",
+                    "column": {"column_name": "num_girls", "type": "BIGINT(20)"},
+                    "aggregate": "SUM",
+                    # the same underlying expression, but different label
+                    "label": "SUM(num_girls)",
+                },
+                False,
+            ],
+            # reference the ambiguous alias in SIMPLE metric
+            [
+                {
+                    "expressionType": "SIMPLE",
+                    "column": {"column_name": "num_boys", "type": "BIGINT(20)"},
+                    "aggregate": "AVG",
+                    "label": "AVG(num_boys)",
+                },
+                False,
+            ],
+            # reference the ambiguous alias in CUSTOM SQL metric
+            [
+                {
+                    "expressionType": "SQL",
+                    "sqlExpression": "MAX(CASE WHEN num_boys > 0 THEN 1 ELSE 0 END)",
+                    "label": "MAX(CASE WHEN...",
+                },
+                True,
+            ],
+        ],
+    },
+    "birth_names:only_orderby_has_metric": {"metrics": [],},
+}
+
+ANNOTATION_LAYERS = {
+    AnnotationType.FORMULA: {
+        "annotationType": "FORMULA",
+        "color": "#ff7f44",
+        "hideLine": False,
+        "name": "my formula",
+        "opacity": "",
+        "overrides": {"time_range": None},
+        "show": True,
+        "showMarkers": False,
+        "sourceType": "",
+        "style": "solid",
+        "value": "3+x",
+        "width": 5,
+    },
+    AnnotationType.EVENT: {
+        "name": "my event",
+        "annotationType": "EVENT",
+        "sourceType": "NATIVE",
+        "color": "#e04355",
+        "opacity": "",
+        "style": "solid",
+        "width": 5,
+        "showMarkers": False,
+        "hideLine": False,
+        "value": 1,
+        "overrides": {"time_range": None},
+        "show": True,
+        "titleColumn": "",
+        "descriptionColumns": [],
+        "timeColumn": "",
+        "intervalEndColumn": "",
+    },
+    AnnotationType.INTERVAL: {
+        "name": "my interval",
+        "annotationType": "INTERVAL",
+        "sourceType": "NATIVE",
+        "color": "#e04355",
+        "opacity": "",
+        "style": "solid",
+        "width": 1,
+        "showMarkers": False,
+        "hideLine": False,
+        "value": 1,
+        "overrides": {"time_range": None},
+        "show": True,
+        "titleColumn": "",
+        "descriptionColumns": [],
+        "timeColumn": "",
+        "intervalEndColumn": "",
+    },
+    AnnotationType.TIME_SERIES: {
+        "annotationType": "TIME_SERIES",
+        "color": None,
+        "descriptionColumns": [],
+        "hideLine": False,
+        "intervalEndColumn": "",
+        "name": "my line",
+        "opacity": "",
+        "overrides": {"time_range": None},
+        "show": True,
+        "showMarkers": False,
+        "sourceType": "line",
+        "style": "dashed",
+        "timeColumn": "",
+        "titleColumn": "",
+        "value": 837,
+        "width": 5,
+    },
+}
+
+POSTPROCESSING_OPERATIONS = {
+    "birth_names": [
+        {
+            "operation": "aggregate",
+            "options": {
+                "groupby": ["gender"],
+                "aggregates": {
+                    "q1": {
+                        "operator": "percentile",
+                        "column": "sum__num",
+                        "options": {"q": 25},
+                    },
+                    "median": {"operator": "median", "column": "sum__num",},
+                },
+            },
+        },
+        {"operation": "sort", "options": {"columns": {"q1": False, "gender": True},},},
+    ]
+}
+
+
+def get_query_object(
+    query_name: str, add_postprocessing_operations: bool, add_time_offsets: bool,
+) -> Dict[str, Any]:
+    if query_name not in QUERY_OBJECTS:
+        raise Exception(f"QueryObject fixture not defined for datasource: {query_name}")
+    obj = QUERY_OBJECTS[query_name]
+
+    # apply overrides
+    if ":" in query_name:
+        parent_query_name = query_name.split(":")[0]
+        obj = {
+            **QUERY_OBJECTS[parent_query_name],
+            **obj,
+        }
+
+    query_object = copy.deepcopy(obj)
+    if add_postprocessing_operations:
+        query_object["post_processing"] = _get_postprocessing_operation(query_name)
+    if add_time_offsets:
+        query_object["time_offsets"] = ["1 year ago"]
+
+    return query_object
+
+
+def _get_postprocessing_operation(query_name: str) -> List[Dict[str, Any]]:
+    if query_name not in QUERY_OBJECTS:
+        raise Exception(
+            f"Post-processing fixture not defined for datasource: {query_name}"
+        )
+    return copy.deepcopy(POSTPROCESSING_OPERATIONS[query_name])
+
+
+@dataclasses.dataclass
+class Table:
+    id: int
+    type: str
+    name: str
+
+
+class QueryContextGenerator:
+    def generate(
+        self,
+        query_name: str,
+        add_postprocessing_operations: bool = False,
+        add_time_offsets: bool = False,
+        table_id=1,
+        table_type="table",
+    ) -> Dict[str, Any]:
+        table_name = query_name.split(":")[0]
+        table = self.get_table(table_name, table_id, table_type)
+        return {
+            "datasource": {"id": table.id, "type": table.type},
+            "queries": [
+                get_query_object(
+                    query_name, add_postprocessing_operations, add_time_offsets,
+                )
+            ],
+            "result_type": ChartDataResultType.FULL,
+        }
+
+    def get_table(self, name, id, type):
+        return Table(id, type, name)
diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py
index bc30019d34832..9a93cb479b5e1 100644
--- a/tests/integration_tests/charts/data/api_tests.py
+++ b/tests/integration_tests/charts/data/api_tests.py
@@ -53,11 +53,8 @@
 )
 from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
 
-
-from tests.integration_tests.fixtures.query_context import (
-    get_query_context,
-    ANNOTATION_LAYERS,
-)
+from tests.common.query_context_generator import ANNOTATION_LAYERS
+from tests.integration_tests.fixtures.query_context import get_query_context
 
 CHART_DATA_URI = "api/v1/chart/data"
 
diff --git a/tests/integration_tests/charts/schema_tests.py b/tests/integration_tests/charts/schema_tests.py
index 977cf72957396..95e1c07ef4805 100644
--- a/tests/integration_tests/charts/schema_tests.py
+++ b/tests/integration_tests/charts/schema_tests.py
@@ -32,29 +32,13 @@
 
 class TestSchema(SupersetTestCase):
     @mock.patch(
-        "superset.common.query_object.config", {**app.config, "ROW_LIMIT": 5000},
+        "superset.common.query_context.config", {**app.config, "ROW_LIMIT": 5000},
     )
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_query_context_limit_and_offset(self):
         self.login(username="admin")
         payload = get_query_context("birth_names")
 
-        # Use defaults
-        payload["queries"][0].pop("row_limit", None)
-        payload["queries"][0].pop("row_offset", None)
-        query_context = ChartDataQueryContextSchema().load(payload)
-        query_object = query_context.queries[0]
-        self.assertEqual(query_object.row_limit, 5000)
-        self.assertEqual(query_object.row_offset, 0)
-
-        # Valid limit and offset
-        payload["queries"][0]["row_limit"] = 100
-        payload["queries"][0]["row_offset"] = 200
-        query_context = ChartDataQueryContextSchema().load(payload)
-        query_object = query_context.queries[0]
-        self.assertEqual(query_object.row_limit, 100)
-        self.assertEqual(query_object.row_offset, 200)
-
         # too low limit and offset
         payload["queries"][0]["row_limit"] = -1
         payload["queries"][0]["row_offset"] = -1
@@ -91,12 +75,3 @@ def test_query_context_series_limit(self):
             "label": "COUNT_DISTINCT(gender)",
         }
         _ = ChartDataQueryContextSchema().load(payload)
-
-    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
-    def test_query_context_null_post_processing_op(self):
-        self.login(username="admin")
-        payload = get_query_context("birth_names")
-
-        payload["queries"][0]["post_processing"] = [None]
-        query_context = ChartDataQueryContextSchema().load(payload)
-        self.assertEqual(query_context.queries[0].post_processing, [])
diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py
index e0614f2ddd5ef..3b9a3976fc14f 100644
--- a/tests/integration_tests/core_tests.py
+++ b/tests/integration_tests/core_tests.py
@@ -749,7 +749,6 @@ def test_templated_sql_json(self):
         data = self.run_sql(sql, "fdaklj3ws")
         self.assertEqual(data["data"][0]["test"], "2")
 
-    @pytest.mark.ofek
     @mock.patch(
         "tests.integration_tests.superset_test_custom_template_processors.datetime"
     )
diff --git a/tests/integration_tests/fixtures/query_context.py b/tests/integration_tests/fixtures/query_context.py
index d36a01087753f..40892e75738d0 100644
--- a/tests/integration_tests/fixtures/query_context.py
+++ b/tests/integration_tests/fixtures/query_context.py
@@ -14,216 +14,15 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import copy
-from typing import Any, Dict, List
+from typing import Any, Dict
 
-from superset.utils.core import AnnotationType, DTTM_ALIAS, TimeRangeEndpoint
+from tests.common.query_context_generator import QueryContextGenerator
 from tests.integration_tests.base_tests import SupersetTestCase
 
-query_birth_names = {
-    "extras": {
-        "where": "",
-        "time_range_endpoints": (
-            TimeRangeEndpoint.INCLUSIVE,
-            TimeRangeEndpoint.EXCLUSIVE,
-        ),
-        "time_grain_sqla": "P1D",
-    },
-    "columns": ["name"],
-    "metrics": [{"label": "sum__num"}],
-    "orderby": [("sum__num", False)],
-    "row_limit": 100,
-    "granularity": "ds",
-    "time_range": "100 years ago : now",
-    "timeseries_limit": 0,
-    "timeseries_limit_metric": None,
-    "order_desc": True,
-    "filters": [
-        {"col": "gender", "op": "==", "val": "boy"},
-        {"col": "num", "op": "IS NOT NULL"},
-        {"col": "name", "op": "NOT IN", "val": ["", '"abc"']},
-    ],
-    "having": "",
-    "having_filters": [],
-    "where": "",
-}
-
-QUERY_OBJECTS: Dict[str, Dict[str, object]] = {
-    "birth_names": query_birth_names,
-    # `:suffix` are overrides only
-    "birth_names:include_time": {"groupby": [DTTM_ALIAS, "name"],},
-    "birth_names:orderby_dup_alias": {
-        "metrics": [
-            {
-                "expressionType": "SIMPLE",
-                "column": {"column_name": "num_girls", "type": "BIGINT(20)"},
-                "aggregate": "SUM",
-                "label": "num_girls",
-            },
-            {
-                "expressionType": "SIMPLE",
-                "column": {"column_name": "num_boys", "type": "BIGINT(20)"},
-                "aggregate": "SUM",
-                "label": "num_boys",
-            },
-        ],
-        "orderby": [
-            [
-                {
-                    "expressionType": "SIMPLE",
-                    "column": {"column_name": "num_girls", "type": "BIGINT(20)"},
-                    "aggregate": "SUM",
-                    # the same underlying expression, but different label
-                    "label": "SUM(num_girls)",
-                },
-                False,
-            ],
-            # reference the ambiguous alias in SIMPLE metric
-            [
-                {
-                    "expressionType": "SIMPLE",
-                    "column": {"column_name": "num_boys", "type": "BIGINT(20)"},
-                    "aggregate": "AVG",
-                    "label": "AVG(num_boys)",
-                },
-                False,
-            ],
-            # reference the ambiguous alias in CUSTOM SQL metric
-            [
-                {
-                    "expressionType": "SQL",
-                    "sqlExpression": "MAX(CASE WHEN num_boys > 0 THEN 1 ELSE 0 END)",
-                    "label": "MAX(CASE WHEN...",
-                },
-                True,
-            ],
-        ],
-    },
-    "birth_names:only_orderby_has_metric": {"metrics": [],},
-}
-
-ANNOTATION_LAYERS = {
-    AnnotationType.FORMULA: {
-        "annotationType": "FORMULA",
-        "color": "#ff7f44",
-        "hideLine": False,
-        "name": "my formula",
-        "opacity": "",
-        "overrides": {"time_range": None},
-        "show": True,
-        "showMarkers": False,
-        "sourceType": "",
-        "style": "solid",
-        "value": "3+x",
-        "width": 5,
-    },
-    AnnotationType.EVENT: {
-        "name": "my event",
-        "annotationType": "EVENT",
-        "sourceType": "NATIVE",
-        "color": "#e04355",
-        "opacity": "",
-        "style": "solid",
-        "width": 5,
-        "showMarkers": False,
-        "hideLine": False,
-        "value": 1,
-        "overrides": {"time_range": None},
-        "show": True,
-        "titleColumn": "",
-        "descriptionColumns": [],
-        "timeColumn": "",
-        "intervalEndColumn": "",
-    },
-    AnnotationType.INTERVAL: {
-        "name": "my interval",
-        "annotationType": "INTERVAL",
-        "sourceType": "NATIVE",
-        "color": "#e04355",
-        "opacity": "",
-        "style": "solid",
-        "width": 1,
-        "showMarkers": False,
-        "hideLine": False,
-        "value": 1,
-        "overrides": {"time_range": None},
-        "show": True,
-        "titleColumn": "",
-        "descriptionColumns": [],
-        "timeColumn": "",
-        "intervalEndColumn": "",
-    },
-    AnnotationType.TIME_SERIES: {
-        "annotationType": "TIME_SERIES",
-        "color": None,
-        "descriptionColumns": [],
-        "hideLine": False,
-        "intervalEndColumn": "",
-        "name": "my line",
-        "opacity": "",
-        "overrides": {"time_range": None},
-        "show": True,
-        "showMarkers": False,
-        "sourceType": "line",
-        "style": "dashed",
-        "timeColumn": "",
-        "titleColumn": "",
-        "value": 837,
-        "width": 5,
-    },
-}
-
-POSTPROCESSING_OPERATIONS = {
-    "birth_names": [
-        {
-            "operation": "aggregate",
-            "options": {
-                "groupby": ["gender"],
-                "aggregates": {
-                    "q1": {
-                        "operator": "percentile",
-                        "column": "sum__num",
-                        "options": {"q": 25},
-                    },
-                    "median": {"operator": "median", "column": "sum__num",},
-                },
-            },
-        },
-        {"operation": "sort", "options": {"columns": {"q1": False, "gender": True},},},
-    ]
-}
-
-
-def get_query_object(
-    query_name: str, add_postprocessing_operations: bool, add_time_offsets: bool,
-) -> Dict[str, Any]:
-    if query_name not in QUERY_OBJECTS:
-        raise Exception(f"QueryObject fixture not defined for datasource: {query_name}")
-    obj = QUERY_OBJECTS[query_name]
-
-    # apply overrides
-    if ":" in query_name:
-        parent_query_name = query_name.split(":")[0]
-        obj = {
-            **QUERY_OBJECTS[parent_query_name],
-            **obj,
-        }
-
-    query_object = copy.deepcopy(obj)
-    if add_postprocessing_operations:
-        query_object["post_processing"] = _get_postprocessing_operation(query_name)
-    if add_time_offsets:
-        query_object["time_offsets"] = ["1 year ago"]
-
-    return query_object
-
-
-def _get_postprocessing_operation(query_name: str) -> List[Dict[str, Any]]:
-    if query_name not in QUERY_OBJECTS:
-        raise Exception(
-            f"Post-processing fixture not defined for datasource: {query_name}"
-        )
-    return copy.deepcopy(POSTPROCESSING_OPERATIONS[query_name])
+
+class QueryContextGeneratorInteg(QueryContextGenerator):
+    def get_table(self, name, id, type):
+        return SupersetTestCase.get_table(name=name)
 
 
 def get_query_context(
@@ -235,7 +34,6 @@ def get_query_context(
     Create a request payload for retrieving a QueryContext object via the
     `api/v1/chart/data` endpoint. By default returns a payload corresponding to one
     generated by the "Boy Name Cloud" chart in the examples.
-
     :param query_name: name of an example query, which is always in the format
     of `datasource_name[:test_case_name]`, where `:test_case_name` is optional.
     :param datasource_id: id of datasource to query.
@@ -244,13 +42,6 @@ def get_query_context(
     :param add_time_offsets: Add time offsets to QueryObject(advanced analytics)
     :return: Request payload
     """
-    table_name = query_name.split(":")[0]
-    table = SupersetTestCase.get_table(name=table_name)
-    return {
-        "datasource": {"id": table.id, "type": table.type},
-        "queries": [
-            get_query_object(
-                query_name, add_postprocessing_operations, add_time_offsets,
-            )
-        ],
-    }
+    return QueryContextGeneratorInteg().generate(
+        query_name, add_postprocessing_operations, add_time_offsets
+    )
diff --git a/tests/unit_tests/common/__init__.py b/tests/unit_tests/common/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/tests/unit_tests/common/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/unit_tests/common/test_query_object_factory.py b/tests/unit_tests/common/test_query_object_factory.py
new file mode 100644
index 0000000000000..4e10fcc3c2d4c
--- /dev/null
+++ b/tests/unit_tests/common/test_query_object_factory.py
@@ -0,0 +1,116 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Any, Dict, Optional
+from unittest.mock import Mock, patch
+
+from pytest import fixture, mark
+
+from superset.common.query_object_factory import QueryObjectFactory
+from tests.common.query_context_generator import QueryContextGenerator
+
+
+def create_app_config() -> Dict[str, Any]:
+    return {
+        "ROW_LIMIT": 5000,
+        "DEFAULT_RELATIVE_START_TIME": "today",
+        "DEFAULT_RELATIVE_END_TIME": "today",
+        "SAMPLES_ROW_LIMIT": 1000,
+        "SIP_15_ENABLED": True,
+        "SQL_MAX_ROW": 100000,
+        "SIP_15_GRACE_PERIOD_END": None,
+    }
+
+
+@fixture
+def app_config() -> Dict[str, Any]:
+    return create_app_config().copy()
+
+
+@fixture
+def session_factory() -> Mock:
+    return Mock()
+
+
+@fixture
+def connector_registry() -> Mock:
+    return Mock(spec=["get_datasource"])
+
+
+def apply_max_row_limit(limit: int, max_limit: Optional[int] = None) -> int:
+    if max_limit is None:
+        max_limit = create_app_config()["SQL_MAX_ROW"]
+    if limit != 0:
+        return min(max_limit, limit)
+    return max_limit
+
+
+@fixture
+def query_object_factory(
+    app_config: Dict[str, Any], connector_registry: Mock, session_factory: Mock
+) -> QueryObjectFactory:
+    import superset.common.query_object_factory as mod
+
+    mod.apply_max_row_limit = apply_max_row_limit
+    return QueryObjectFactory(app_config, connector_registry, session_factory)
+
+
+@fixture
+def raw_query_context() -> Dict[str, Any]:
+    return QueryContextGenerator().generate("birth_names")
+
+
+class TestQueryObjectFactory:
+    def test_query_context_limit_and_offset_defaults(
+        self,
+        query_object_factory: QueryObjectFactory,
+        raw_query_context: Dict[str, Any],
+    ):
+        raw_query_object = raw_query_context["queries"][0]
+        raw_query_object.pop("row_limit", None)
+        raw_query_object.pop("row_offset", None)
+        query_object = query_object_factory.create(
+            raw_query_context["result_type"], **raw_query_object
+        )
+        assert query_object.row_limit == 5000
+        assert query_object.row_offset == 0
+
+    def test_query_context_limit(
+        self,
+        query_object_factory: QueryObjectFactory,
+        raw_query_context: Dict[str, Any],
+    ):
+        raw_query_object = raw_query_context["queries"][0]
+        raw_query_object["row_limit"] = 100
+        raw_query_object["row_offset"] = 200
+        query_object = query_object_factory.create(
+            raw_query_context["result_type"], **raw_query_object
+        )
+
+        assert query_object.row_limit == 100
+        assert query_object.row_offset == 200
+
+    def test_query_context_null_post_processing_op(
+        self,
+        query_object_factory: QueryObjectFactory,
+        raw_query_context: Dict[str, Any],
+    ):
+        raw_query_object = raw_query_context["queries"][0]
+        raw_query_object["post_processing"] = [None]
+        query_object = query_object_factory.create(
+            raw_query_context["result_type"], **raw_query_object
+        )
+        assert query_object.post_processing == []
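
Usage sketch (illustrative only, not part of the commit): the seam this refactor creates is that QueryObjectFactory no longer reaches for the global app config, ConnectorRegistry, or db.session; its collaborators are injected through the constructor, so it can be exercised with plain test doubles and no Flask application context. The wiring below mirrors the new unit tests, including their stub of apply_max_row_limit, whose real implementation still reads the application config.

    from unittest.mock import Mock

    import superset.common.query_object_factory as factory_module
    from superset.common.query_object_factory import QueryObjectFactory
    from tests.common.query_context_generator import QueryContextGenerator

    # Stub out apply_max_row_limit, as test_query_object_factory.py does,
    # so that no live Flask app is needed.
    factory_module.apply_max_row_limit = lambda limit, max_limit=None: limit

    factory = QueryObjectFactory(
        app_configurations={
            "ROW_LIMIT": 5000,
            "SAMPLES_ROW_LIMIT": 1000,
            "DEFAULT_RELATIVE_START_TIME": "today",
            "DEFAULT_RELATIVE_END_TIME": "today",
            "SIP_15_ENABLED": False,
            "SIP_15_GRACE_PERIOD_END": None,
        },
        connector_registry=Mock(spec=["get_datasource"]),  # double for ConnectorRegistry
        session_maker=Mock(),  # double for the SQLAlchemy session factory
    )

    # Build a QueryObject from the shared birth_names fixture, exactly as the
    # new unit tests do; no datasource lookup or DB session is ever exercised.
    raw_context = QueryContextGenerator().generate("birth_names")
    query_object = factory.create(raw_context["result_type"], **raw_context["queries"][0])
    assert query_object.row_limit == 100  # row_limit comes straight from the fixture

In production the same factory is instead wired with the real dependencies by create_query_object_factory() in superset/common/query_context.py: config, ConnectorRegistry(), and db.session.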