refactor(QueryObject): decouple from queryContext and clean code #17465

Merged 1 commit on Nov 18, 2021
superset/common/query_context.py (16 changes: 9 additions & 7 deletions)
@@ -80,10 +80,10 @@ class QueryContext:

     datasource: BaseDatasource
     queries: List[QueryObject]
-    force: bool
-    custom_cache_timeout: Optional[int]
     result_type: ChartDataResultType
     result_format: ChartDataResultFormat
+    force: bool
+    custom_cache_timeout: Optional[int]

     # TODO: Type datasource and query_object dictionary with TypedDict when it becomes
     # a vanilla python type https://github.com/python/mypy/issues/5288
@@ -92,19 +92,21 @@ def __init__(
         self,
         datasource: DatasourceDict,
         queries: List[Dict[str, Any]],
-        force: bool = False,
-        custom_cache_timeout: Optional[int] = None,
         result_type: Optional[ChartDataResultType] = None,
         result_format: Optional[ChartDataResultFormat] = None,
+        force: bool = False,
+        custom_cache_timeout: Optional[int] = None,
     ) -> None:
         self.datasource = ConnectorRegistry.get_datasource(
             str(datasource["type"]), int(datasource["id"]), db.session
         )
-        self.force = force
-        self.custom_cache_timeout = custom_cache_timeout
         self.result_type = result_type or ChartDataResultType.FULL
         self.result_format = result_format or ChartDataResultFormat.JSON
-        self.queries = [QueryObject(self, **query_obj) for query_obj in queries]
+        self.queries = [
+            QueryObject(self.result_type, **query_obj) for query_obj in queries
+        ]
+        self.force = force
+        self.custom_cache_timeout = custom_cache_timeout
         self.cache_values = {
             "datasource": datasource,
             "queries": queries,
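The functional change in this file is the last hunk: instead of handing each QueryObject a reference to the whole QueryContext, the context now passes only its result_type, which removes the circular parent reference. A minimal sketch of the new construction shape, with simplified signatures; only the class names and the kwargs.get("result_type", ...) fallback mirror the diffs in this PR:

from enum import Enum
from typing import Any, Dict, List


class ChartDataResultType(str, Enum):
    FULL = "full"
    SAMPLES = "samples"


class QueryObject:
    # after the refactor, only the parent's result type crosses the boundary
    def __init__(self, parent_result_type: ChartDataResultType, **kwargs: Any) -> None:
        # an explicit result_type in the query payload still wins over the parent's
        self.result_type = kwargs.get("result_type", parent_result_type)


class QueryContext:
    def __init__(
        self, queries: List[Dict[str, Any]], result_type: ChartDataResultType
    ) -> None:
        self.result_type = result_type
        # was: QueryObject(self, **query_obj), a back-reference to the whole context
        self.queries = [
            QueryObject(self.result_type, **query_obj) for query_obj in queries
        ]


context = QueryContext(queries=[{}], result_type=ChartDataResultType.FULL)
assert context.queries[0].result_type is ChartDataResultType.FULL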
superset/common/query_object.py (141 changes: 86 additions & 55 deletions)
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=invalid-name
+from __future__ import annotations
+
 import logging
 from datetime import datetime, timedelta
 from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING
@@ -106,11 +108,12 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
     series_limit_metric: Optional[Metric]
     time_offsets: List[str]
     time_shift: Optional[timedelta]
+    time_range: Optional[str]
     to_dttm: Optional[datetime]

     def __init__(  # pylint: disable=too-many-arguments,too-many-locals
         self,
-        query_context: "QueryContext",
+        parent_result_type: ChartDataResultType,
         annotation_layers: Optional[List[Dict[str, Any]]] = None,
         applied_time_extras: Optional[Dict[str, str]] = None,
         apply_fetch_values_predicate: bool = False,
@@ -125,7 +128,6 @@ def __init__(  # pylint: disable=too-many-arguments,too-many-locals
         order_desc: bool = True,
         orderby: Optional[List[OrderBy]] = None,
         post_processing: Optional[List[Optional[Dict[str, Any]]]] = None,
-        result_type: Optional[ChartDataResultType] = None,
         row_limit: Optional[int] = None,
         row_offset: Optional[int] = None,
         series_columns: Optional[List[Column]] = None,
@@ -135,88 +137,117 @@ def __init__(  # pylint: disable=too-many-arguments,too-many-locals
         time_shift: Optional[str] = None,
         **kwargs: Any,
     ):
-        columns = columns or []
-        extras = extras or {}
-        annotation_layers = annotation_layers or []
+        self.result_type = kwargs.get("result_type", parent_result_type)
+        self._set_annotation_layers(annotation_layers)
+        self.applied_time_extras = applied_time_extras or {}
+        self.apply_fetch_values_predicate = apply_fetch_values_predicate or False
+        self.columns = columns or []
+        self._set_datasource(datasource)
+        self._set_extras(extras)
+        self.filter = filters or []
+        self.granularity = granularity
+        self.is_rowcount = is_rowcount
+        self._set_is_timeseries(is_timeseries)
+        self._set_metrics(metrics)
+        self.order_desc = order_desc
+        self.orderby = orderby or []
+        self._set_post_processing(post_processing)
+        self._set_row_limit(row_limit)
+        self.row_offset = row_offset or 0
+        self._init_series_columns(series_columns, metrics, is_timeseries)
+        self.series_limit = series_limit
+        self.series_limit_metric = series_limit_metric
+        self.set_dttms(time_range, time_shift)
+        self.time_range = time_range
+        self.time_shift = parse_human_timedelta(time_shift)
         self.time_offsets = kwargs.get("time_offsets", [])
         self.inner_from_dttm = kwargs.get("inner_from_dttm")
         self.inner_to_dttm = kwargs.get("inner_to_dttm")
-        if series_columns:
-            self.series_columns = series_columns
-        elif is_timeseries and metrics:
-            self.series_columns = columns
-        else:
-            self.series_columns = []
+        self._rename_deprecated_fields(kwargs)
+        self._move_deprecated_extra_fields(kwargs)

-        self.is_rowcount = is_rowcount
-        self.datasource = None
-        if datasource:
-            self.datasource = ConnectorRegistry.get_datasource(
-                str(datasource["type"]), int(datasource["id"]), db.session
-            )
-        self.result_type = result_type or query_context.result_type
-        self.apply_fetch_values_predicate = apply_fetch_values_predicate or False
+    def _set_annotation_layers(
+        self, annotation_layers: Optional[List[Dict[str, Any]]]
+    ) -> None:
         self.annotation_layers = [
             layer
-            for layer in annotation_layers
+            for layer in (annotation_layers or [])
             # formula annotations don't affect the payload, hence can be dropped
             if layer["annotationType"] != "FORMULA"
         ]
-        self.applied_time_extras = applied_time_extras or {}
-        self.granularity = granularity
-        self.from_dttm, self.to_dttm = get_since_until(
-            relative_start=extras.get(
-                "relative_start", config["DEFAULT_RELATIVE_START_TIME"]
-            ),
-            relative_end=extras.get(
-                "relative_end", config["DEFAULT_RELATIVE_END_TIME"]
-            ),
-            time_range=time_range,
-            time_shift=time_shift,
-        )

+    def _set_datasource(self, datasource: Optional[DatasourceDict]) -> None:
+        self.datasource = None
+        if datasource:
+            self.datasource = ConnectorRegistry.get_datasource(
+                str(datasource["type"]), int(datasource["id"]), db.session
+            )
+
+    def _set_extras(self, extras: Optional[Dict[str, Any]]) -> None:
+        self.extras = extras or {}
+        if config["SIP_15_ENABLED"]:
+            self.extras["time_range_endpoints"] = get_time_range_endpoints(
+                form_data=self.extras
+            )
+
+    def _set_is_timeseries(self, is_timeseries: Optional[bool]) -> None:
         # is_timeseries is True if time column is in either columns or groupby
         # (both are dimensions)
         self.is_timeseries = (
-            is_timeseries if is_timeseries is not None else DTTM_ALIAS in columns
+            is_timeseries if is_timeseries is not None else DTTM_ALIAS in self.columns
        )
-        self.time_range = time_range
-        self.time_shift = parse_human_timedelta(time_shift)
-        self.post_processing = [
-            post_proc for post_proc in post_processing or [] if post_proc
-        ]

+    def _set_metrics(self, metrics: Optional[List[Metric]] = None) -> None:
         # Support metric reference/definition in the format of
         # 1. 'metric_name' - name of predefined metric
         # 2. { label: 'label_name' } - legacy format for a predefined metric
         # 3. { expressionType: 'SIMPLE' | 'SQL', ... } - adhoc metric
+        def is_str_or_adhoc(metric: Metric) -> bool:
+            return isinstance(metric, str) or is_adhoc_metric(metric)
+
         self.metrics = metrics and [
-            x if isinstance(x, str) or is_adhoc_metric(x) else x["label"]  # type: ignore
-            for x in metrics
+            x if is_str_or_adhoc(x) else x["label"] for x in metrics  # type: ignore
         ]

+    def _set_post_processing(
+        self, post_processing: Optional[List[Optional[Dict[str, Any]]]]
+    ) -> None:
+        self.post_processing = [
+            post_proc for post_proc in post_processing or [] if post_proc
+        ]
+
+    def _set_row_limit(self, row_limit: Optional[int]) -> None:
         default_row_limit = (
             config["SAMPLES_ROW_LIMIT"]
             if self.result_type == ChartDataResultType.SAMPLES
             else config["ROW_LIMIT"]
         )
         self.row_limit = apply_max_row_limit(row_limit or default_row_limit)
-        self.row_offset = row_offset or 0
-        self.filter = filters or []
-        self.series_limit = series_limit
-        self.series_limit_metric = series_limit_metric
-        self.order_desc = order_desc
-        self.extras = extras

-        if config["SIP_15_ENABLED"]:
-            self.extras["time_range_endpoints"] = get_time_range_endpoints(
-                form_data=self.extras
-            )
-
-        self.columns = columns
-        self.orderby = orderby or []
+    def _init_series_columns(
+        self,
+        series_columns: Optional[List[Column]],
+        metrics: Optional[List[Metric]],
+        is_timeseries: Optional[bool],
+    ) -> None:
+        if series_columns:
+            self.series_columns = series_columns
+        elif is_timeseries and metrics:
+            self.series_columns = self.columns
+        else:
+            self.series_columns = []

-        self._rename_deprecated_fields(kwargs)
-        self._move_deprecated_extra_fields(kwargs)
+    def set_dttms(self, time_range: Optional[str], time_shift: Optional[str]) -> None:
+        self.from_dttm, self.to_dttm = get_since_until(
+            relative_start=self.extras.get(
+                "relative_start", config["DEFAULT_RELATIVE_START_TIME"]
+            ),
+            relative_end=self.extras.get(
+                "relative_end", config["DEFAULT_RELATIVE_END_TIME"]
+            ),
+            time_range=time_range,
+            time_shift=time_shift,
+        )

     def _rename_deprecated_fields(self, kwargs: Dict[str, Any]) -> None:
         # rename deprecated fields
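The bulk of the query_object.py diff replaces one monolithic __init__ with small private setters. Note that the helpers read state established by earlier assignments: _set_row_limit branches on self.result_type, _set_is_timeseries reads self.columns, and set_dttms reads self.extras, so the call order inside __init__ is load-bearing. A minimal sketch of that ordering constraint; the ROW_LIMIT and SAMPLES_ROW_LIMIT constants are illustrative stand-ins for Superset's config values:

from enum import Enum
from typing import Optional


class ChartDataResultType(str, Enum):
    FULL = "full"
    SAMPLES = "samples"


# illustrative stand-ins for config["ROW_LIMIT"] and config["SAMPLES_ROW_LIMIT"]
ROW_LIMIT = 50_000
SAMPLES_ROW_LIMIT = 1_000


class QueryObject:
    def __init__(
        self,
        parent_result_type: ChartDataResultType,
        row_limit: Optional[int] = None,
    ) -> None:
        # result_type must be assigned before _set_row_limit(), which branches on it
        self.result_type = parent_result_type
        self._set_row_limit(row_limit)

    def _set_row_limit(self, row_limit: Optional[int]) -> None:
        default_row_limit = (
            SAMPLES_ROW_LIMIT
            if self.result_type == ChartDataResultType.SAMPLES
            else ROW_LIMIT
        )
        self.row_limit = row_limit or default_row_limit


# a SAMPLES query falls back to the samples limit; an explicit limit wins
assert QueryObject(ChartDataResultType.SAMPLES).row_limit == 1_000
assert QueryObject(ChartDataResultType.FULL, row_limit=100).row_limit == 100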