refactor(QueryObject): add QueryObjectFactory to meet SRP #17466

Merged · 1 commit · Nov 18, 2021
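This change extracts construction-time concerns out of QueryObject.__init__ (datasource lookup through ConnectorRegistry, SIP-15 extras processing, row-limit defaulting, and from/to datetime resolution) into a new QueryObjectFactory class, so QueryObject no longer reaches into the app config or the database session while it is being built. QueryContext now creates its queries through the factory.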
6 changes: 4 additions & 2 deletions superset/common/query_context.py
@@ -32,7 +32,7 @@
 from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
 from superset.common.db_query_status import QueryStatus
 from superset.common.query_actions import get_query_results
-from superset.common.query_object import QueryObject
+from superset.common.query_object import QueryObject, QueryObjectFactory
 from superset.common.utils import QueryCacheManager
 from superset.connectors.base.models import BaseDatasource
 from superset.connectors.connector_registry import ConnectorRegistry
@@ -102,8 +102,10 @@ def __init__(
         )
         self.result_type = result_type or ChartDataResultType.FULL
         self.result_format = result_format or ChartDataResultFormat.JSON
+        query_object_factory = QueryObjectFactory()
         self.queries = [
-            QueryObject(self.result_type, **query_obj) for query_obj in queries
+            query_object_factory.create(self.result_type, **query_obj)
+            for query_obj in queries
         ]
         self.force = force
         self.custom_cache_timeout = custom_cache_timeout
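For illustration, the new call path in one sketch. This is not part of the diff; the datasource type and id below are hypothetical, and a running Superset app context is assumed:

    from superset.common.chart_data import ChartDataResultType
    from superset.common.query_object import QueryObjectFactory

    factory = QueryObjectFactory()
    query_object = factory.create(
        ChartDataResultType.FULL,               # parent_result_type, as passed by QueryContext
        datasource={"type": "table", "id": 1},  # hypothetical; resolved to a BaseDatasource model
        row_limit=None,                         # the factory falls back to config["ROW_LIMIT"]
        time_range="Last week",                 # the factory derives from_dttm/to_dttm from this
    )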
133 changes: 83 additions & 50 deletions superset/common/query_object.py
@@ -14,19 +14,18 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name, no-self-use
 from __future__ import annotations
 
 import logging
 from datetime import datetime, timedelta
-from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING
+from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TYPE_CHECKING
 
 from flask_babel import gettext as _
 from pandas import DataFrame
 
 from superset import app, db
 from superset.common.chart_data import ChartDataResultType
-from superset.connectors.base.models import BaseDatasource
 from superset.connectors.connector_registry import ConnectorRegistry
 from superset.exceptions import QueryObjectValidationError
 from superset.typing import Column, Metric, OrderBy
@@ -47,7 +46,7 @@
 from superset.views.utils import get_time_range_endpoints
 
 if TYPE_CHECKING:
-    from superset.common.query_context import QueryContext  # pragma: no cover
+    from superset.connectors.base.models import BaseDatasource
 
 
 config = app.config
@@ -111,14 +110,14 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
     time_range: Optional[str]
     to_dttm: Optional[datetime]
 
-    def __init__(  # pylint: disable=too-many-arguments,too-many-locals
+    def __init__(  # pylint: disable=too-many-locals
         self,
         parent_result_type: ChartDataResultType,
         *,
         annotation_layers: Optional[List[Dict[str, Any]]] = None,
         applied_time_extras: Optional[Dict[str, str]] = None,
         apply_fetch_values_predicate: bool = False,
         columns: Optional[List[Column]] = None,
-        datasource: Optional[DatasourceDict] = None,
+        datasource: Optional[BaseDatasource] = None,
         extras: Optional[Dict[str, Any]] = None,
         filters: Optional[List[QueryObjectFilterClause]] = None,
         granularity: Optional[str] = None,
@@ -128,7 +127,7 @@ def __init__(  # pylint: disable=too-many-arguments,too-many-locals
         order_desc: bool = True,
         orderby: Optional[List[OrderBy]] = None,
         post_processing: Optional[List[Optional[Dict[str, Any]]]] = None,
-        row_limit: Optional[int] = None,
+        row_limit: int,
         row_offset: Optional[int] = None,
         series_columns: Optional[List[Column]] = None,
         series_limit: int = 0,
@@ -137,13 +136,12 @@
         time_shift: Optional[str] = None,
         **kwargs: Any,
     ):
-        self.result_type = kwargs.get("result_type", parent_result_type)
         self._set_annotation_layers(annotation_layers)
         self.applied_time_extras = applied_time_extras or {}
         self.apply_fetch_values_predicate = apply_fetch_values_predicate or False
         self.columns = columns or []
-        self._set_datasource(datasource)
-        self._set_extras(extras)
+        self.datasource = datasource
+        self.extras = extras or {}
         self.filter = filters or []
         self.granularity = granularity
         self.is_rowcount = is_rowcount
@@ -152,14 +150,16 @@
         self.order_desc = order_desc
         self.orderby = orderby or []
         self._set_post_processing(post_processing)
-        self._set_row_limit(row_limit)
+        self.row_limit = row_limit
         self.row_offset = row_offset or 0
         self._init_series_columns(series_columns, metrics, is_timeseries)
         self.series_limit = series_limit
         self.series_limit_metric = series_limit_metric
-        self.set_dttms(time_range, time_shift)
         self.time_range = time_range
         self.time_shift = parse_human_timedelta(time_shift)
+        self.from_dttm = kwargs.get("from_dttm")
+        self.to_dttm = kwargs.get("to_dttm")
+        self.result_type = kwargs.get("result_type")
         self.time_offsets = kwargs.get("time_offsets", [])
         self.inner_from_dttm = kwargs.get("inner_from_dttm")
         self.inner_to_dttm = kwargs.get("inner_to_dttm")
@@ -176,20 +176,6 @@ def _set_annotation_layers(
             if layer["annotationType"] != "FORMULA"
         ]
 
-    def _set_datasource(self, datasource: Optional[DatasourceDict]) -> None:
-        self.datasource = None
-        if datasource:
-            self.datasource = ConnectorRegistry.get_datasource(
-                str(datasource["type"]), int(datasource["id"]), db.session
-            )
-
-    def _set_extras(self, extras: Optional[Dict[str, Any]]) -> None:
-        self.extras = extras or {}
-        if config["SIP_15_ENABLED"]:
-            self.extras["time_range_endpoints"] = get_time_range_endpoints(
-                form_data=self.extras
-            )
-
     def _set_is_timeseries(self, is_timeseries: Optional[bool]) -> None:
         # is_timeseries is True if time column is in either columns or groupby
         # (both are dimensions)
@@ -212,17 +198,8 @@ def is_str_or_adhoc(metric: Metric) -> bool:
     def _set_post_processing(
         self, post_processing: Optional[List[Optional[Dict[str, Any]]]]
     ) -> None:
-        self.post_processing = [
-            post_proc for post_proc in post_processing or [] if post_proc
-        ]
-
-    def _set_row_limit(self, row_limit: Optional[int]) -> None:
-        default_row_limit = (
-            config["SAMPLES_ROW_LIMIT"]
-            if self.result_type == ChartDataResultType.SAMPLES
-            else config["ROW_LIMIT"]
-        )
-        self.row_limit = apply_max_row_limit(row_limit or default_row_limit)
+        post_processing = post_processing or []
+        self.post_processing = [post_proc for post_proc in post_processing if post_proc]
 
     def _init_series_columns(
         self,
@@ -237,18 +214,6 @@ def _init_series_columns(
         else:
             self.series_columns = []
 
-    def set_dttms(self, time_range: Optional[str], time_shift: Optional[str]) -> None:
-        self.from_dttm, self.to_dttm = get_since_until(
-            relative_start=self.extras.get(
-                "relative_start", config["DEFAULT_RELATIVE_START_TIME"]
-            ),
-            relative_end=self.extras.get(
-                "relative_end", config["DEFAULT_RELATIVE_END_TIME"]
-            ),
-            time_range=time_range,
-            time_shift=time_shift,
-        )
-
     def _rename_deprecated_fields(self, kwargs: Dict[str, Any]) -> None:
         # rename deprecated fields
         for field in DEPRECATED_FIELDS:
@@ -439,3 +404,71 @@ def exec_post_processing(self, df: DataFrame) -> DataFrame:
                 options = post_process.get("options", {})
                 df = getattr(pandas_postprocessing, operation)(df, **options)
         return df
+
+
+class QueryObjectFactory:  # pylint: disable=too-few-public-methods
+    def create(  # pylint: disable=too-many-arguments
+        self,
+        parent_result_type: ChartDataResultType,
+        datasource: Optional[DatasourceDict] = None,
+        extras: Optional[Dict[str, Any]] = None,
+        row_limit: Optional[int] = None,
+        time_range: Optional[str] = None,
+        time_shift: Optional[str] = None,
+        **kwargs: Any,
+    ) -> QueryObject:
+        datasource_model_instance = None
+        if datasource:
+            datasource_model_instance = self._convert_to_model(datasource)
+        processed_extras = self._process_extras(extras)
+        result_type = kwargs.setdefault("result_type", parent_result_type)
+        row_limit = self._process_row_limit(row_limit, result_type)
+        from_dttm, to_dttm = self._get_dttms(time_range, time_shift, processed_extras)
+        kwargs["from_dttm"] = from_dttm
+        kwargs["to_dttm"] = to_dttm
+        return QueryObject(
+            datasource=datasource_model_instance,
+            extras=extras,
+            row_limit=row_limit,
+            time_range=time_range,
+            time_shift=time_shift,
+            **kwargs,
+        )
+
+    def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
+        return ConnectorRegistry.get_datasource(
+            str(datasource["type"]), int(datasource["id"]), db.session
+        )
+
+    def _process_extras(self, extras: Optional[Dict[str, Any]]) -> Dict[str, Any]:
+        extras = extras or {}
+        if config["SIP_15_ENABLED"]:
+            extras["time_range_endpoints"] = get_time_range_endpoints(form_data=extras)
+        return extras
+
+    def _process_row_limit(
+        self, row_limit: Optional[int], result_type: ChartDataResultType
+    ) -> int:
+        default_row_limit = (
+            config["SAMPLES_ROW_LIMIT"]
+            if result_type == ChartDataResultType.SAMPLES
+            else config["ROW_LIMIT"]
+        )
+        return apply_max_row_limit(row_limit or default_row_limit)
+
+    def _get_dttms(
+        self,
+        time_range: Optional[str],
+        time_shift: Optional[str],
+        extras: Dict[str, Any],
+    ) -> Tuple[Optional[datetime], Optional[datetime]]:
+        return get_since_until(
+            relative_start=extras.get(
+                "relative_start", config["DEFAULT_RELATIVE_START_TIME"]
+            ),
+            relative_end=extras.get(
+                "relative_end", config["DEFAULT_RELATIVE_END_TIME"]
+            ),
+            time_range=time_range,
+            time_shift=time_shift,
+        )
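The shape of the refactor, reduced to a minimal self-contained sketch (illustrative names and config values only, not Superset code): the factory owns config-dependent defaulting and clamping, while the product class simply stores what it is given. That is the single-responsibility split the PR title refers to.

    from typing import Optional

    # illustrative config values, not Superset's actual defaults
    CONFIG = {"ROW_LIMIT": 50_000, "SAMPLES_ROW_LIMIT": 1_000, "MAX_ROW_LIMIT": 100_000}

    class Query:
        def __init__(self, *, row_limit: int) -> None:
            self.row_limit = row_limit  # no config or DB session access here

    class QueryFactory:
        def create(self, row_limit: Optional[int] = None, samples: bool = False) -> Query:
            # defaulting and clamping live in the factory, analogous to
            # _process_row_limit in the diff above
            default = CONFIG["SAMPLES_ROW_LIMIT"] if samples else CONFIG["ROW_LIMIT"]
            return Query(row_limit=min(row_limit or default, CONFIG["MAX_ROW_LIMIT"]))

    assert QueryFactory().create().row_limit == 50_000
    assert QueryFactory().create(samples=True).row_limit == 1_000
    assert QueryFactory().create(row_limit=10**9).row_limit == 100_000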
1 change: 1 addition & 0 deletions tests/integration_tests/query_context_tests.py
@@ -50,6 +50,7 @@ def get_sql_text(payload: Dict[str, Any]) -> str:
 
 
 class TestQueryContext(SupersetTestCase):
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_schema_deserialization(self):
         """
         Ensure that the deserialized QueryContext contains all required fields.
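The new usefixtures marker presumably guarantees that the birth_names datasource exists in the metadata database before this test runs, since deserializing a QueryContext constructs its query objects and that requires the referenced datasource to be resolvable.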