Skip to content

Commit

Permalink
refactor(QueryObject): decouple from superset (#17479)
Browse files Browse the repository at this point in the history
* refactor: queryObject - decouple from superset

* refactor: queryObject - decouple from superset
  • Loading branch information
ofekisr authored and AAfghahi committed Jan 10, 2022
1 parent fd6993d commit f3ca3bd
Show file tree
Hide file tree
Showing 11 changed files with 562 additions and 329 deletions.
9 changes: 7 additions & 2 deletions superset/common/query_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.db_query_status import QueryStatus
from superset.common.query_actions import get_query_results
from superset.common.query_object import QueryObject, QueryObjectFactory
from superset.common.query_object import QueryObject
from superset.common.query_object_factory import QueryObjectFactory
from superset.common.utils import QueryCacheManager
from superset.connectors.base.models import BaseDatasource
from superset.connectors.connector_registry import ConnectorRegistry
Expand Down Expand Up @@ -69,6 +70,10 @@ class CachedTimeOffset(TypedDict):
cache_keys: List[Optional[str]]


def create_query_object_factory() -> QueryObjectFactory:
return QueryObjectFactory(config, ConnectorRegistry(), db.session)


class QueryContext:
"""
The query context contains the query object and additional fields necessary
Expand Down Expand Up @@ -102,7 +107,7 @@ def __init__(
)
self.result_type = result_type or ChartDataResultType.FULL
self.result_format = result_format or ChartDataResultFormat.JSON
query_object_factory = QueryObjectFactory()
query_object_factory = create_query_object_factory()
self.queries = [
query_object_factory.create(self.result_type, **query_obj)
for query_obj in queries
Expand Down
81 changes: 3 additions & 78 deletions superset/common/query_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,21 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, no-self-use
# pylint: disable=invalid-name
from __future__ import annotations

import logging
from datetime import datetime, timedelta
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TYPE_CHECKING
from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING

from flask_babel import gettext as _
from pandas import DataFrame

from superset import app, db
from superset.common.chart_data import ChartDataResultType
from superset.connectors.connector_registry import ConnectorRegistry
from superset.exceptions import QueryObjectValidationError
from superset.typing import Column, Metric, OrderBy
from superset.utils import pandas_postprocessing
from superset.utils.core import (
apply_max_row_limit,
DatasourceDict,
DTTM_ALIAS,
find_duplicates,
get_column_names,
Expand All @@ -41,15 +37,12 @@
json_int_dttm_ser,
QueryObjectFilterClause,
)
from superset.utils.date_parser import get_since_until, parse_human_timedelta
from superset.utils.date_parser import parse_human_timedelta
from superset.utils.hashing import md5_sha_from_dict
from superset.views.utils import get_time_range_endpoints

if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource


config = app.config
logger = logging.getLogger(__name__)

# TODO: Type Metrics dictionary with TypedDict when it becomes a vanilla python type
Expand Down Expand Up @@ -404,71 +397,3 @@ def exec_post_processing(self, df: DataFrame) -> DataFrame:
options = post_process.get("options", {})
df = getattr(pandas_postprocessing, operation)(df, **options)
return df


class QueryObjectFactory: # pylint: disable=too-few-public-methods
def create( # pylint: disable=too-many-arguments
self,
parent_result_type: ChartDataResultType,
datasource: Optional[DatasourceDict] = None,
extras: Optional[Dict[str, Any]] = None,
row_limit: Optional[int] = None,
time_range: Optional[str] = None,
time_shift: Optional[str] = None,
**kwargs: Any,
) -> QueryObject:
datasource_model_instance = None
if datasource:
datasource_model_instance = self._convert_to_model(datasource)
processed_extras = self._process_extras(extras)
result_type = kwargs.setdefault("result_type", parent_result_type)
row_limit = self._process_row_limit(row_limit, result_type)
from_dttm, to_dttm = self._get_dttms(time_range, time_shift, processed_extras)
kwargs["from_dttm"] = from_dttm
kwargs["to_dttm"] = to_dttm
return QueryObject(
datasource=datasource_model_instance,
extras=extras,
row_limit=row_limit,
time_range=time_range,
time_shift=time_shift,
**kwargs,
)

def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
return ConnectorRegistry.get_datasource(
str(datasource["type"]), int(datasource["id"]), db.session
)

def _process_extras(self, extras: Optional[Dict[str, Any]]) -> Dict[str, Any]:
extras = extras or {}
if config["SIP_15_ENABLED"]:
extras["time_range_endpoints"] = get_time_range_endpoints(form_data=extras)
return extras

def _process_row_limit(
self, row_limit: Optional[int], result_type: ChartDataResultType
) -> int:
default_row_limit = (
config["SAMPLES_ROW_LIMIT"]
if result_type == ChartDataResultType.SAMPLES
else config["ROW_LIMIT"]
)
return apply_max_row_limit(row_limit or default_row_limit)

def _get_dttms(
self,
time_range: Optional[str],
time_shift: Optional[str],
extras: Dict[str, Any],
) -> Tuple[Optional[datetime], Optional[datetime]]:
return get_since_until(
relative_start=extras.get(
"relative_start", config["DEFAULT_RELATIVE_START_TIME"]
),
relative_end=extras.get(
"relative_end", config["DEFAULT_RELATIVE_END_TIME"]
),
time_range=time_range,
time_shift=time_shift,
)
134 changes: 134 additions & 0 deletions superset/common/query_object_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from datetime import date, datetime
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING

from superset.common.chart_data import ChartDataResultType
from superset.common.query_object import QueryObject
from superset.utils.core import apply_max_row_limit, DatasourceDict, TimeRangeEndpoint
from superset.utils.date_parser import get_since_until

if TYPE_CHECKING:
from sqlalchemy.orm import sessionmaker

from superset import ConnectorRegistry
from superset.connectors.base.models import BaseDatasource


class QueryObjectFactory: # pylint: disable=too-few-public-methods
_config: Dict[str, Any]
_connector_registry: ConnectorRegistry
_session_maker: sessionmaker

def __init__(
self,
app_configurations: Dict[str, Any],
connector_registry: ConnectorRegistry,
session_maker: sessionmaker,
):
self._config = app_configurations
self._connector_registry = connector_registry
self._session_maker = session_maker

def create( # pylint: disable=too-many-arguments
self,
parent_result_type: ChartDataResultType,
datasource: Optional[DatasourceDict] = None,
extras: Optional[Dict[str, Any]] = None,
row_limit: Optional[int] = None,
time_range: Optional[str] = None,
time_shift: Optional[str] = None,
**kwargs: Any,
) -> QueryObject:
datasource_model_instance = None
if datasource:
datasource_model_instance = self._convert_to_model(datasource)
processed_extras = self._process_extras(extras)
result_type = kwargs.setdefault("result_type", parent_result_type)
row_limit = self._process_row_limit(row_limit, result_type)
from_dttm, to_dttm = self._get_dttms(time_range, time_shift, processed_extras)
kwargs["from_dttm"] = from_dttm
kwargs["to_dttm"] = to_dttm
return QueryObject(
datasource=datasource_model_instance,
extras=extras,
row_limit=row_limit,
time_range=time_range,
time_shift=time_shift,
**kwargs,
)

def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
return self._connector_registry.get_datasource(
str(datasource["type"]), int(datasource["id"]), self._session_maker()
)

def _process_extras(self, extras: Optional[Dict[str, Any]]) -> Dict[str, Any]:
extras = extras or {}
if self._config["SIP_15_ENABLED"]:
extras["time_range_endpoints"] = self._determine_time_range_endpoints(
extras.get("time_range_endpoints")
)
return extras

def _process_row_limit(
self, row_limit: Optional[int], result_type: ChartDataResultType
) -> int:
default_row_limit = (
self._config["SAMPLES_ROW_LIMIT"]
if result_type == ChartDataResultType.SAMPLES
else self._config["ROW_LIMIT"]
)
return apply_max_row_limit(row_limit or default_row_limit)

def _get_dttms(
self,
time_range: Optional[str],
time_shift: Optional[str],
extras: Dict[str, Any],
) -> Tuple[Optional[datetime], Optional[datetime]]:
return get_since_until(
relative_start=extras.get(
"relative_start", self._config["DEFAULT_RELATIVE_START_TIME"]
),
relative_end=extras.get(
"relative_end", self._config["DEFAULT_RELATIVE_END_TIME"]
),
time_range=time_range,
time_shift=time_shift,
)

# light version of the view.utils.core
# import view.utils require application context
# Todo: move it and the view.utils.core to utils package

def _determine_time_range_endpoints(
self, raw_endpoints: Optional[Tuple[str, str]] = None,
) -> Optional[Tuple[TimeRangeEndpoint, TimeRangeEndpoint]]:
if (
self._config["SIP_15_GRACE_PERIOD_END"]
and date.today() >= self._config["SIP_15_GRACE_PERIOD_END"]
):
return TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.EXCLUSIVE

if raw_endpoints:
start, end = raw_endpoints
return TimeRangeEndpoint(start), TimeRangeEndpoint(end)

return TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.EXCLUSIVE
16 changes: 16 additions & 0 deletions tests/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
Loading

0 comments on commit f3ca3bd

Please sign in to comment.