Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: QueryContext and QueryObject by decoupling and making them testable #17344

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
factories
  • Loading branch information
ofekisr committed Nov 10, 2021
commit 57b2929472c123c0138e05312d4e67c642bfb63b
29 changes: 22 additions & 7 deletions superset/charts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-lines
from __future__ import annotations
import json
import logging
from datetime import datetime
from io import BytesIO
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, TYPE_CHECKING
from zipfile import ZipFile

import simplejson
Expand Down Expand Up @@ -63,7 +64,7 @@
get_fav_star_ids_schema,
openapi_spec_methods_override,
screenshot_query_schema,
thumbnail_query_schema,
thumbnail_query_schema, ChartDataQueryContextSchema,
)
from superset.commands.importers.exceptions import NoValidFilesFoundError
from superset.commands.importers.v1.utils import get_contents_from_bundle
Expand All @@ -88,12 +89,15 @@
from superset.views.core import CsvResponse, generate_download_headers
from superset.views.filters import FilterRelatedOwners

logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from superset.common.query_factory import QueryContextFactory

logger = logging.getLogger(__name__)

class ChartRestApi(BaseSupersetModelRestApi):
datamodel = SQLAInterface(Slice)

datamodel = SQLAInterface(Slice)
query_context_factory: QueryContextFactory
resource_name = "chart"
allow_browser_login = True

Expand Down Expand Up @@ -619,7 +623,7 @@ def get_data(self, pk: int) -> Response:

try:
command = ChartDataCommand()
query_context = command.set_query_context(json_body)
query_context = command.set_query_context(self._convert_to_dict(json_body))
command.validate()
except QueryObjectValidationError as error:
return self.response_400(message=error.message)
Expand Down Expand Up @@ -705,7 +709,9 @@ def data(self) -> Response:

try:
command = ChartDataCommand()
query_context = command.set_query_context(json_body)
raw_query_context = self._convert_to_dict(json_body)
command.set_form_data(json_body)
query_context = command.set_query_context(raw_query_context)
command.validate()
except QueryObjectValidationError as error:
return self.response_400(message=error.message)
Expand All @@ -726,6 +732,11 @@ def data(self) -> Response:

return self.get_data_response(command)

def _convert_to_dict(self, json_body: Dict[str, Any]) -> Any:
    """Deserialize a chart-data request body and build a QueryContext.

    The marshmallow schema validates/normalizes the raw JSON payload, and
    the injected ``query_context_factory`` turns it into a QueryContext.

    NOTE(review): despite the name, this returns a QueryContext object,
    not a dict — consider renaming (e.g. ``_create_query_context``).

    :param json_body: raw request payload as parsed JSON
    :returns: the constructed QueryContext
    """
    raw_data = ChartDataQueryContextSchema().load(json_body)
    return self.query_context_factory.create_from_dict(raw_data)

def _run_async(self, command: ChartDataCommand) -> Response:
"""
Execute command as an async query.
Expand Down Expand Up @@ -797,7 +808,8 @@ def data_from_cache(self, cache_key: str) -> Response:
command = ChartDataCommand()
try:
cached_data = command.load_query_context_from_cache(cache_key)
command.set_query_context(cached_data)
query_context = self._convert_to_dict(cached_data)
command.set_query_context(query_context)
command.validate()
except ChartDataCacheLoadError:
return self.response_404()
Expand Down Expand Up @@ -1202,3 +1214,6 @@ def import_(self) -> Response:
)
command.run()
return self.response(200, message="OK")

def set_query_context_factory(self, query_context_factory: QueryContextFactory) -> None:
    """Inject the factory used to build QueryContext objects.

    Setter-injection so tests can supply a fake factory (the stated goal
    of this refactor).
    """
    self.query_context_factory = query_context_factory
19 changes: 8 additions & 11 deletions superset/charts/commands/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@


class ChartDataCommand(BaseCommand):
_query_context: QueryContext
_form_data: Dict[str, Any]

def __init__(self) -> None:
    """Initialize the command; most state is supplied later via setters."""
    # Bare annotations only *declare* these attributes — they are assigned
    # afterwards (see set_form_data()/set_query_context()); reading them
    # before that raises AttributeError.
    self._form_data: Dict[str, Any]
    self._query_context: QueryContext
    self._async_channel_id: str
    self._query_context_processor = QueryContextProcessor()

Expand Down Expand Up @@ -78,17 +79,13 @@ def run_async(self, user_id: Optional[str]) -> Dict[str, Any]:

return job_metadata

def set_query_context(self, form_data: Dict[str, Any]) -> QueryContext:
    """Parse *form_data* into a QueryContext and store it on the command.

    NOTE(review): annotated to return QueryContext but no value is ever
    returned — callers receive None. Either return the context or change
    the annotation.
    """
    self._form_data = form_data
    try:
        # Schema.load validates the payload; presumably a post_load hook
        # builds the QueryContext — confirm against ChartDataQueryContextSchema.
        self._query_context = ChartDataQueryContextSchema().load(self._form_data)
    except KeyError as ex:
        raise ValidationError("Request is incorrect") from ex
    except ValidationError as error:
        # NOTE(review): re-raising unchanged is a no-op — this handler
        # could be removed without changing behavior.
        raise error

def set_query_context(self, query_context: QueryContext) -> QueryContext:
    """Attach *query_context* to the command and echo it back for chaining."""
    self._query_context = query_context
    return query_context

def set_form_data(self, form_data: Dict[str, Any]) -> None:
    """Store the raw request payload on the command for later use."""
    self._form_data = form_data

def validate(self) -> None:
self._query_context_processor.raise_for_access(self._query_context)

Expand Down
10 changes: 1 addition & 9 deletions superset/charts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,13 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-lines
from typing import Any, Dict

from flask_babel import gettext as _
from marshmallow import EXCLUDE, fields, post_load, Schema, validate
from marshmallow import EXCLUDE, fields, Schema, validate
from marshmallow.validate import Length, Range
from marshmallow_enum import EnumField

from superset import app
from superset.common.query_context import QueryContext
from superset.common.query_factory import QueryContextFactory
from superset.db_engine_specs.base import builtin_time_grains
from superset.utils import schema as utils
from superset.utils.core import (
Expand Down Expand Up @@ -1137,11 +1134,6 @@ class ChartDataQueryContextSchema(Schema):
result_type = EnumField(ChartDataResultType, by_value=True)
result_format = EnumField(ChartDataResultFormat, by_value=True)

# pylint: disable=no-self-use,unused-argument
@post_load
def make_query_context(self, data: Dict[str, Any], **kwargs: Any) -> QueryContext:
return QueryContextFactory.create_from_dict(data)


class AnnotationDataSchema(Schema):
columns = fields.List(
Expand Down
144 changes: 37 additions & 107 deletions superset/common/query_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,41 +18,58 @@

import json
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, TYPE_CHECKING

from superset import app, ConnectorRegistry, db
from superset.common.query_context import QueryContext
from superset.common.query_object import QueryObject
from superset.typing import Metric, OrderBy
from superset.common.query_object_factory import (
QueryObjectFactory,
QueryObjectFactoryImpl,
)
from superset.utils.core import (
apply_max_row_limit,
ChartDataResultFormat,
ChartDataResultType,
DatasourceDict,
QueryObjectFilterClause,
)
from superset.utils.date_parser import get_since_until, parse_human_timedelta
from superset.views.utils import get_time_range_endpoints

config = app.config
if TYPE_CHECKING:
from superset import ConnectorRegistry
from superset.common.query_object import QueryObject

logger = logging.getLogger(__name__)


class QueryContextFactory:
@classmethod
query_object_factory: QueryObjectFactory
session_factory: Any
connector_registry: ConnectorRegistry

def __init__(
    self,
    app_config: Dict[str, Any],
    connector_registry: ConnectorRegistry,
    session_factory: Any,
) -> None:
    """Wire the factory's collaborators.

    :param app_config: Flask app config mapping (row limits, defaults, …)
    :param connector_registry: used to resolve datasources by type/id
    :param session_factory: callable producing a DB session — injected so
        tests can substitute their own (presumably wraps ``db.session``;
        confirm at the call site)
    """
    self.connector_registry = connector_registry
    self.session_factory = session_factory
    # The query-object factory shares the same collaborators.
    self.query_object_factory = QueryObjectFactoryImpl(
        app_config, connector_registry, session_factory
    )

def create( # pylint: disable=too-many-arguments
cls,
self,
datasource_dict: DatasourceDict,
queries_dicts: List[Dict[str, Any]],
force: bool = False,
custom_cache_timeout: Optional[int] = None,
result_type: ChartDataResultType = ChartDataResultType.FULL,
result_format: Optional[ChartDataResultFormat] = None,
) -> QueryContext:
datasource = ConnectorRegistry.get_datasource(
str(datasource_dict["type"]), int(datasource_dict["id"]), db.session
datasource = self.connector_registry.get_datasource(
str(datasource_dict["type"]),
int(datasource_dict["id"]),
self.session_factory(),
)
queries = cls.create_queries_object(queries_dicts, result_type)
queries = self.create_queries_object(queries_dicts, result_type)
return QueryContext(
datasource,
queries,
Expand All @@ -64,112 +81,25 @@ def create( # pylint: disable=too-many-arguments
raw_queries=queries_dicts,
)

def create_from_dict(self, raw_query_context: Dict[str, Any]) -> QueryContext:
    """Build a QueryContext from a raw request payload.

    Translates the wire-format keys (``datasource``, ``queries``) into the
    parameter names expected by ``create()``.

    NOTE(review): mutates *raw_query_context* in place — pass a copy if the
    caller still needs the original keys.

    :param raw_query_context: deserialized request body
    :returns: the constructed QueryContext
    """
    # Diff residue removed: the superseded @classmethod signature and a
    # dead `return cls.create(...)` line preceded the instance-method
    # return, leaving it unreachable and referencing an undefined `cls`.
    if "datasource" in raw_query_context:
        raw_query_context["datasource_dict"] = raw_query_context.pop(
            "datasource", {}
        )
    if "queries" in raw_query_context:
        raw_query_context["queries_dicts"] = raw_query_context.pop("queries", [])
    return self.create(**raw_query_context)

def create_from_json(self, raw_query_context: str) -> QueryContext:
    """Parse a JSON string and delegate to ``create_from_dict()``.

    :param raw_query_context: JSON-encoded query-context payload
    :returns: the constructed QueryContext
    """
    # Diff residue removed: a stray @classmethod decorator, a duplicate
    # old `def` line, and a `cls.create_from_dict(...)` call coexisted
    # with this instance-method version.
    raw_data = json.loads(raw_query_context)
    return self.create_from_dict(raw_data)

def create_queries_object(
    self, queries_dicts: List[Dict[str, Any]], result_type: ChartDataResultType
) -> List[QueryObject]:
    """Build a QueryObject for each raw query dict.

    NOTE(review): mutates each input dict by defaulting its
    ``result_type`` to the context-level value.

    :param queries_dicts: raw per-query payloads
    :param result_type: fallback result type when a query omits its own
    :returns: list of constructed QueryObjects
    """
    # Diff residue removed: the superseded @staticmethod signature and
    # the old `QueryObjectFactory.create(...)` call coexisted with this
    # instance-method version that delegates to the injected factory.
    for query_dict in queries_dicts:
        query_dict.setdefault("result_type", result_type)
    return [
        self.query_object_factory.create(**query_dict)
        for query_dict in queries_dicts
    ]


class QueryObjectFactory:  # pylint: disable=too-few-public-methods
    """Builds QueryObject instances from raw chart-data request parameters."""

    @classmethod
    def create(  # pylint: disable=too-many-arguments, too-many-locals
        cls,
        annotation_layers: Optional[List[Dict[str, Any]]] = None,
        applied_time_extras: Optional[Dict[str, str]] = None,
        apply_fetch_values_predicate: bool = False,
        columns: Optional[List[str]] = None,
        datasource_dict: Optional[DatasourceDict] = None,
        extras: Optional[Dict[str, Any]] = None,
        filters: Optional[List[QueryObjectFilterClause]] = None,
        granularity: Optional[str] = None,
        is_rowcount: bool = False,
        is_timeseries: Optional[bool] = None,
        metrics: Optional[List[Metric]] = None,
        order_desc: bool = True,
        orderby: Optional[List[OrderBy]] = None,
        post_processing: Optional[List[Optional[Dict[str, Any]]]] = None,
        result_type: Optional[ChartDataResultType] = None,
        row_limit: Optional[int] = None,
        row_offset: Optional[int] = None,
        series_columns: Optional[List[str]] = None,
        series_limit: int = 0,
        series_limit_metric: Optional[Metric] = None,
        time_range: Optional[str] = None,
        time_shift: Optional[str] = None,
        **kwargs: Any,
    ) -> QueryObject:
        """Construct a QueryObject, applying config-driven defaults.

        Resolves the datasource (if given), converts the human-readable
        time range/shift into concrete datetimes, defaults and caps the
        row limit, and forwards everything positionally to QueryObject.

        :raises: whatever ConnectorRegistry/get_since_until raise on bad input
        """
        datasource = None
        if datasource_dict:
            datasource = ConnectorRegistry.get_datasource(
                str(datasource_dict["type"]), int(datasource_dict["id"]), db.session
            )
        extras = extras or {}
        # Relative start/end fall back to app-config defaults when the
        # request does not supply them.
        from_dttm, to_dttm = get_since_until(
            relative_start=extras.get(
                "relative_start", config["DEFAULT_RELATIVE_START_TIME"]
            ),
            relative_end=extras.get(
                "relative_end", config["DEFAULT_RELATIVE_END_TIME"]
            ),
            time_range=time_range,
            time_shift=time_shift,
        )
        # SAMPLES results get their own (smaller) default row limit.
        if row_limit is None:
            row_limit = (
                config["SAMPLES_ROW_LIMIT"]
                if result_type == ChartDataResultType.SAMPLES
                else config["ROW_LIMIT"]
            )

        # Cap the requested limit at the configured maximum.
        actual_row_limit = apply_max_row_limit(row_limit)
        if config["SIP_15_ENABLED"]:
            extras["time_range_endpoints"] = get_time_range_endpoints(form_data=extras)

        # Positional forwarding: argument order must match QueryObject's
        # __init__ exactly.
        return QueryObject(
            annotation_layers,
            applied_time_extras,
            apply_fetch_values_predicate,
            columns,
            datasource,
            extras,
            filters,
            granularity,
            is_rowcount,
            is_timeseries,
            metrics,
            order_desc,
            orderby,
            post_processing,
            result_type,
            actual_row_limit,
            row_offset,
            series_columns,
            series_limit,
            series_limit_metric,
            from_dttm,
            to_dttm,
            parse_human_timedelta(time_shift),
            time_range,
            **kwargs,
        )
9 changes: 7 additions & 2 deletions superset/common/query_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
from __future__ import annotations
import logging
from datetime import datetime, timedelta
from typing import Any, Dict, List, NamedTuple, Optional
from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING

from flask_babel import gettext as _
from pandas import DataFrame

from superset.connectors.base.models import BaseDatasource

from superset.exceptions import QueryObjectValidationError
from superset.typing import Metric, OrderBy
from superset.utils import pandas_postprocessing
Expand All @@ -37,6 +38,10 @@
)
from superset.utils.hashing import md5_sha_from_dict

if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource


logger = logging.getLogger(__name__)

# TODO: Type Metrics dictionary with TypedDict when it becomes a vanilla python type
Expand Down
Loading