Skip to content

Commit

Permalink
refactor(QueryContext): add QueryContextFactory (#17495)
Browse files Browse the repository at this point in the history
  • Loading branch information
ofekisr authored Nov 21, 2021
1 parent 261e418 commit 8a6ecd3
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 48 deletions.
19 changes: 14 additions & 5 deletions superset/charts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-lines
from typing import Any, Dict
from __future__ import annotations

from typing import Any, Dict, Optional, TYPE_CHECKING

from flask_babel import gettext as _
from marshmallow import EXCLUDE, fields, post_load, Schema, validate
Expand All @@ -24,7 +26,7 @@

from superset import app
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.query_context import QueryContext
from superset.common.query_context_factory import QueryContextFactory
from superset.db_engine_specs.base import builtin_time_grains
from superset.utils import schema as utils
from superset.utils.core import (
Expand All @@ -35,6 +37,9 @@
TimeRangeEndpoint,
)

if TYPE_CHECKING:
from superset.common.query_context import QueryContext

config = app.config

#
Expand Down Expand Up @@ -1129,6 +1134,7 @@ class Meta: # pylint: disable=too-few-public-methods


class ChartDataQueryContextSchema(Schema):
query_context_factory: Optional[QueryContextFactory] = None
datasource = fields.Nested(ChartDataDatasourceSchema)
queries = fields.List(fields.Nested(ChartDataQueryObjectSchema))
force = fields.Boolean(
Expand All @@ -1139,13 +1145,16 @@ class ChartDataQueryContextSchema(Schema):
result_type = EnumField(ChartDataResultType, by_value=True)
result_format = EnumField(ChartDataResultFormat, by_value=True)

# pylint: disable=no-self-use,unused-argument
# pylint: disable=unused-argument
@post_load
def make_query_context(self, data: Dict[str, Any], **kwargs: Any) -> QueryContext:
query_context = QueryContext(**data)
query_context = self.get_query_context_factory().create(**data)
return query_context

# pylint: enable=no-self-use,unused-argument
def get_query_context_factory(self) -> QueryContextFactory:
    """Return the factory used to build query contexts for this schema.

    An already-assigned ``query_context_factory`` is returned untouched;
    otherwise a new ``QueryContextFactory`` is created, stored on the
    instance and returned, so later calls reuse the same object.
    """
    factory = self.query_context_factory
    if factory is not None:
        return factory
    factory = QueryContextFactory()
    self.query_context_factory = factory
    return factory


class AnnotationDataSchema(Schema):
Expand Down
45 changes: 15 additions & 30 deletions superset/common/query_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,21 @@
from pandas import DateOffset
from typing_extensions import TypedDict

from superset import app, db, is_feature_enabled
from superset import app, is_feature_enabled
from superset.annotation_layers.dao import AnnotationLayerDAO
from superset.charts.dao import ChartDAO
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.db_query_status import QueryStatus
from superset.common.query_actions import get_query_results
from superset.common.query_object import QueryObject
from superset.common.query_object_factory import QueryObjectFactory
from superset.common.utils import QueryCacheManager
from superset.connectors.base.models import BaseDatasource
from superset.connectors.connector_registry import ConnectorRegistry
from superset.constants import CacheRegion
from superset.exceptions import QueryObjectValidationError, SupersetException
from superset.extensions import cache_manager, security_manager
from superset.models.helpers import QueryResult
from superset.utils import csv
from superset.utils.cache import generate_cache_key, set_and_log_cache
from superset.utils.core import (
DatasourceDict,
DTTM_ALIAS,
error_msg_from_exception,
get_column_names_from_columns,
Expand All @@ -57,6 +53,7 @@
from superset.views.utils import get_viz

if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource
from superset.stats_logger import BaseStatsLogger

config = app.config
Expand All @@ -70,10 +67,6 @@ class CachedTimeOffset(TypedDict):
cache_keys: List[Optional[str]]


def create_query_object_factory() -> QueryObjectFactory:
return QueryObjectFactory(config, ConnectorRegistry(), db.session)


class QueryContext:
"""
The query context contains the query object and additional fields necessary
Expand All @@ -90,36 +83,28 @@ class QueryContext:
force: bool
custom_cache_timeout: Optional[int]

cache_values: Dict[str, Any]

# TODO: Type datasource and query_object dictionary with TypedDict when it becomes
# a vanilla python type https://github.com/python/mypy/issues/5288
# pylint: disable=too-many-arguments
def __init__(
self,
datasource: DatasourceDict,
queries: List[Dict[str, Any]],
result_type: Optional[ChartDataResultType] = None,
result_format: Optional[ChartDataResultFormat] = None,
*,
datasource: BaseDatasource,
queries: List[QueryObject],
result_type: ChartDataResultType,
result_format: ChartDataResultFormat,
force: bool = False,
custom_cache_timeout: Optional[int] = None,
cache_values: Dict[str, Any]
) -> None:
self.datasource = ConnectorRegistry.get_datasource(
str(datasource["type"]), int(datasource["id"]), db.session
)
self.result_type = result_type or ChartDataResultType.FULL
self.result_format = result_format or ChartDataResultFormat.JSON
query_object_factory = create_query_object_factory()
self.queries = [
query_object_factory.create(self.result_type, **query_obj)
for query_obj in queries
]
self.datasource = datasource
self.result_type = result_type
self.result_format = result_format
self.queries = queries
self.force = force
self.custom_cache_timeout = custom_cache_timeout
self.cache_values = {
"datasource": datasource,
"queries": queries,
"result_type": self.result_type,
"result_format": self.result_format,
}
self.cache_values = cache_values

@staticmethod
def left_join_df(
Expand Down
83 changes: 83 additions & 0 deletions superset/common/query_context_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import Any, Dict, List, Optional, TYPE_CHECKING

from superset import app, db
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.query_context import QueryContext
from superset.common.query_object_factory import QueryObjectFactory
from superset.connectors.connector_registry import ConnectorRegistry
from superset.utils.core import DatasourceDict

if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource

config = app.config


def create_query_object_factory() -> QueryObjectFactory:
    """Build a ``QueryObjectFactory`` from the module-level ``config``,
    a new ``ConnectorRegistry`` instance and ``db.session``."""
    connector_registry = ConnectorRegistry()
    return QueryObjectFactory(config, connector_registry, db.session)


class QueryContextFactory:  # pylint: disable=too-few-public-methods
    """Creates ``QueryContext`` instances from dictionary-based input.

    The factory resolves the datasource dictionary into a datasource model
    and turns each raw query dictionary into a query object before handing
    everything to the ``QueryContext`` constructor.
    """

    _query_object_factory: QueryObjectFactory

    def __init__(self) -> None:
        self._query_object_factory = create_query_object_factory()

    def create(
        self,
        *,
        datasource: DatasourceDict,
        queries: List[Dict[str, Any]],
        result_type: Optional[ChartDataResultType] = None,
        result_format: Optional[ChartDataResultFormat] = None,
        force: bool = False,
        custom_cache_timeout: Optional[int] = None
    ) -> QueryContext:
        """Assemble a ``QueryContext`` from raw request-style arguments.

        :param datasource: dictionary with ``type`` and ``id`` entries; when
            falsy, no lookup is performed and ``None`` is passed on
        :param queries: raw query dictionaries, each expanded as keyword
            arguments into the query object factory
        :param result_type: falls back to ``ChartDataResultType.FULL`` when
            falsy
        :param result_format: falls back to ``ChartDataResultFormat.JSON``
            when falsy
        :param force: forwarded unchanged to ``QueryContext``
        :param custom_cache_timeout: forwarded unchanged to ``QueryContext``
        :returns: the newly constructed ``QueryContext``
        """
        # The datasource is resolved before any query object is built, so a
        # failing lookup surfaces first.
        model = self._convert_to_model(datasource) if datasource else None

        if not result_type:
            result_type = ChartDataResultType.FULL
        if not result_format:
            result_format = ChartDataResultFormat.JSON

        query_objects = []
        for raw_query in queries:
            query_objects.append(
                self._query_object_factory.create(result_type, **raw_query)
            )

        # ``cache_values`` keeps the raw (dictionary) inputs together with
        # the effective result type and format.
        return QueryContext(
            datasource=model,
            queries=query_objects,
            result_type=result_type,
            result_format=result_format,
            force=force,
            custom_cache_timeout=custom_cache_timeout,
            cache_values={
                "datasource": datasource,
                "queries": queries,
                "result_type": result_type,
                "result_format": result_format,
            },
        )

    # pylint: disable=no-self-use
    def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
        """Look up the datasource model for a ``type``/``id`` dictionary."""
        datasource_type = str(datasource["type"])
        datasource_id = int(datasource["id"])
        return ConnectorRegistry.get_datasource(
            datasource_type, datasource_id, db.session
        )
22 changes: 17 additions & 5 deletions superset/models/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import json
import logging
from typing import Any, Dict, Optional, Type, TYPE_CHECKING
Expand Down Expand Up @@ -41,6 +43,7 @@

if TYPE_CHECKING:
from superset.common.query_context import QueryContext
from superset.common.query_context_factory import QueryContextFactory
from superset.connectors.base.models import BaseDatasource

metadata = Model.metadata # pylint: disable=no-member
Expand All @@ -59,6 +62,8 @@ class Slice( # pylint: disable=too-many-public-methods
):
"""A slice is essentially a report or a view on data"""

query_context_factory: Optional[QueryContextFactory] = None

__tablename__ = "slices"
id = Column(Integer, primary_key=True)
slice_name = Column(String(250))
Expand Down Expand Up @@ -248,13 +253,12 @@ def form_data(self) -> Dict[str, Any]:
update_time_range(form_data)
return form_data

def get_query_context(self) -> Optional["QueryContext"]:
# pylint: disable=import-outside-toplevel
from superset.common.query_context import QueryContext

def get_query_context(self) -> Optional[QueryContext]:
if self.query_context:
try:
return QueryContext(**json.loads(self.query_context))
return self.get_query_context_factory().create(
**json.loads(self.query_context)
)
except json.decoder.JSONDecodeError as ex:
logger.error("Malformed json in slice's query context", exc_info=True)
logger.exception(ex)
Expand Down Expand Up @@ -313,6 +317,14 @@ def icons(self) -> str:
def url(self) -> str:
return f"/superset/explore/?form_data=%7B%22slice_id%22%3A%20{self.id}%7D"

def get_query_context_factory(self) -> QueryContextFactory:
    """Return this slice's query context factory, creating it on first use.

    When ``query_context_factory`` is already set it is returned as-is.
    Otherwise ``QueryContextFactory`` is imported inside the method,
    instantiated, stored on the instance and returned.
    """
    factory = self.query_context_factory
    if factory is not None:
        return factory

    # pylint: disable=import-outside-toplevel
    from superset.common.query_context_factory import QueryContextFactory

    factory = QueryContextFactory()
    self.query_context_factory = factory
    return factory


def set_related_perm(_mapper: Mapper, _connection: Connection, target: Slice) -> None:
src_class = target.cls_model
Expand Down
24 changes: 20 additions & 4 deletions superset/views/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any
from __future__ import annotations

from typing import Any, TYPE_CHECKING

import simplejson as json
from flask import request
Expand All @@ -27,31 +29,37 @@
TimeRangeAmbiguousError,
TimeRangeParseFailError,
)
from superset.common.query_context import QueryContext
from superset.legacy import update_time_range
from superset.models.slice import Slice
from superset.typing import FlaskResponse
from superset.utils import core as utils
from superset.utils.date_parser import get_since_until
from superset.views.base import api, BaseSupersetView, handle_api_exception

if TYPE_CHECKING:
from superset.common.query_context_factory import QueryContextFactory

get_time_range_schema = {"type": "string"}


class Api(BaseSupersetView):
query_context_factory = None

@event_logger.log_this
@api
@handle_api_exception
@has_access_api
@expose("/v1/query/", methods=["POST"])
def query(self) -> FlaskResponse: # pylint: disable=no-self-use
def query(self) -> FlaskResponse:
"""
Takes a query_obj constructed in the client and returns payload data response
for the given query_obj.
raises SupersetSecurityException: If the user cannot access the resource
"""
query_context = QueryContext(**json.loads(request.form["query_context"]))
query_context = self.get_query_context_factory().create(
**json.loads(request.form["query_context"])
)
query_context.raise_for_access()
result = query_context.get_payload()
payload_json = result["queries"]
Expand Down Expand Up @@ -99,3 +107,11 @@ def time_range(self, **kwargs: Any) -> FlaskResponse:
except (ValueError, TimeRangeParseFailError, TimeRangeAmbiguousError) as error:
error_msg = {"message": f"Unexpected time range: {error}"}
return self.json_response(error_msg, 400)

def get_query_context_factory(self) -> QueryContextFactory:
    """Return the view's query context factory, creating it on first use.

    A previously assigned ``query_context_factory`` is returned unchanged.
    If none is set, ``QueryContextFactory`` is imported inside the method,
    instantiated, stored on the instance and returned.
    """
    factory = self.query_context_factory
    if factory is not None:
        return factory

    # pylint: disable=import-outside-toplevel
    from superset.common.query_context_factory import QueryContextFactory

    factory = QueryContextFactory()
    self.query_context_factory = factory
    return factory
Loading

0 comments on commit 8a6ecd3

Please sign in to comment.