Skip to content

Commit

Permalink
style(mypy): Enforcing typing for superset.views (apache#9939)
Browse files Browse the repository at this point in the history
Co-authored-by: John Bodley <john.bodley@airbnb.com>
  • Loading branch information
john-bodley and John Bodley authored Jun 5, 2020
1 parent 5c4d4f1 commit 63e0188
Show file tree
Hide file tree
Showing 23 changed files with 440 additions and 340 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ order_by_type = false
ignore_missing_imports = true
no_implicit_optional = true

[mypy-superset,superset.app,superset.bin.*,superset.charts.*,superset.cli,superset.commands.*,superset.common.*,superset.config,superset.connectors.*,superset.constants,superset.dataframe,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.errors,superset.examples.*,superset.exceptions,superset.extensions,superset.forms,superset.jinja_context,superset.legacy,superset.migrations.*,superset.models.*,superset.result_set,superset.queries.*,superset.security.*,superset.sql_lab,superset.sql_parse,superset.sql_validators.*,superset.stats_logger,superset.tasks.*,superset.translations.*,superset.typing,superset.utils.*,superset.views.chart.*,superset.views.dashboard.*,superset.views.database.*,superset.viz,superset.viz_sip38]
[mypy-superset,superset.app,superset.bin.*,superset.charts.*,superset.cli,superset.commands.*,superset.common.*,superset.config,superset.connectors.*,superset.constants,superset.dataframe,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.errors,superset.examples.*,superset.exceptions,superset.extensions,superset.forms,superset.jinja_context,superset.legacy,superset.migrations.*,superset.models.*,superset.result_set,superset.queries.*,superset.security.*,superset.sql_lab,superset.sql_parse,superset.sql_validators.*,superset.stats_logger,superset.tasks.*,superset.translations.*,superset.typing,superset.utils.*,,superset.views.*,superset.viz,superset.viz_sip38]
check_untyped_defs = true
disallow_untyped_calls = true
disallow_untyped_defs = true
18 changes: 15 additions & 3 deletions superset/connectors/base/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,26 @@ class BaseDatasource(
# ---------------------------------------------------------------
__tablename__: Optional[str] = None # {connector_name}_datasource
baselink: Optional[str] = None # url portion pointing to ModelView endpoint
column_class: Optional[Type] = None # link to derivative of BaseColumn
metric_class: Optional[Type] = None # link to derivative of BaseMetric

@property
def column_class(self) -> Type:
# link to derivative of BaseColumn
raise NotImplementedError()

@property
def metric_class(self) -> Type:
# link to derivative of BaseMetric
raise NotImplementedError()

owner_class: Optional[User] = None

# Used to do code highlighting when displaying the query in the UI
query_language: Optional[str] = None

name = None # can be a Column or a property pointing to one
@property
def name(self) -> str:
# can be a Column or a property pointing to one
raise NotImplementedError()

# ---------------------------------------------------------------

Expand Down
2 changes: 1 addition & 1 deletion superset/connectors/druid/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ def num_cols(self) -> List[str]:
return [c.column_name for c in self.columns if c.is_numeric]

@property
def name(self) -> str: # type: ignore
def name(self) -> str:
return self.datasource_name

@property
Expand Down
2 changes: 1 addition & 1 deletion superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ def get_perm(self) -> str:
return ("[{obj.database}].[{obj.table_name}]" "(id:{obj.id})").format(obj=self)

@property
def name(self) -> str: # type: ignore
def name(self) -> str:
if not self.schema:
return self.table_name
return "{}.{}".format(self.schema, self.table_name)
Expand Down
6 changes: 4 additions & 2 deletions superset/sql_validators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

from typing import Any, Dict, List, Optional

from superset.models.core import Database


class SQLValidationAnnotation:
"""Represents a single annotation (error/warning) in an SQL querytext"""
Expand All @@ -35,7 +37,7 @@ def __init__(
self.start_column = start_column
self.end_column = end_column

def to_dict(self) -> Dict:
def to_dict(self) -> Dict[str, Any]:
"""Return a dictionary representation of this annotation"""
return {
"line_number": self.line_number,
Expand All @@ -53,7 +55,7 @@ class BaseSQLValidator:

@classmethod
def validate(
cls, sql: str, schema: str, database: Any
cls, sql: str, schema: Optional[str], database: Database
) -> List[SQLValidationAnnotation]:
"""Check that the given SQL querystring is valid for the given engine"""
raise NotImplementedError
2 changes: 1 addition & 1 deletion superset/sql_validators/presto_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def validate_statement(

@classmethod
def validate(
cls, sql: str, schema: str, database: Any
cls, sql: str, schema: Optional[str], database: Database
) -> List[SQLValidationAnnotation]:
"""
Presto supports query-validation queries by running them with a
Expand Down
6 changes: 4 additions & 2 deletions superset/tasks/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,11 @@ def deliver_dashboard(schedule: DashboardEmailSchedule) -> None:
"""
dashboard = schedule.dashboard

dashboard_url = _get_url_path("Superset.dashboard", dashboard_id=dashboard.id)
dashboard_url = _get_url_path(
"Superset.dashboard", dashboard_id_or_slug=dashboard.id
)
dashboard_url_user_friendly = _get_url_path(
"Superset.dashboard", user_friendly=True, dashboard_id=dashboard.id
"Superset.dashboard", user_friendly=True, dashboard_id_or_slug=dashboard.id
)

# Create a driver, fetch the page, wait for the page to render
Expand Down
8 changes: 5 additions & 3 deletions superset/views/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict

from flask_appbuilder import CompactCRUDMixin
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import lazy_gettext as _
Expand All @@ -30,7 +32,7 @@ class StartEndDttmValidator: # pylint: disable=too-few-public-methods
Validates dttm fields.
"""

def __call__(self, form, field):
def __call__(self, form: Dict[str, Any], field: Any) -> None:
if not form["start_dttm"].data and not form["end_dttm"].data:
raise StopValidation(_("annotation start time or end time is required."))
elif (
Expand Down Expand Up @@ -82,13 +84,13 @@ class AnnotationModelView(

validators_columns = {"start_dttm": [StartEndDttmValidator()]}

def pre_add(self, item):
def pre_add(self, item: "AnnotationModelView") -> None:
if not item.start_dttm:
item.start_dttm = item.end_dttm
elif not item.end_dttm:
item.end_dttm = item.start_dttm

def pre_update(self, item):
def pre_update(self, item: "AnnotationModelView") -> None:
self.pre_add(item)


Expand Down
7 changes: 4 additions & 3 deletions superset/views/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from superset.common.query_context import QueryContext
from superset.legacy import update_time_range
from superset.models.slice import Slice
from superset.typing import FlaskResponse
from superset.utils import core as utils
from superset.views.base import api, BaseSupersetView, handle_api_exception

Expand All @@ -34,13 +35,13 @@ class Api(BaseSupersetView):
@handle_api_exception
@has_access_api
@expose("/v1/query/", methods=["POST"])
def query(self):
def query(self) -> FlaskResponse:
"""
Takes a query_obj constructed in the client and returns payload data response
for the given query_obj.
params: query_context: json_blob
"""
query_context = QueryContext(**json.loads(request.form.get("query_context")))
query_context = QueryContext(**json.loads(request.form["query_context"]))
security_manager.assert_query_context_permission(query_context)
payload_json = query_context.get_payload()
return json.dumps(
Expand All @@ -52,7 +53,7 @@ def query(self):
@handle_api_exception
@has_access_api
@expose("/v1/form_data/", methods=["GET"])
def query_form_data(self):
def query_form_data(self) -> FlaskResponse:
"""
Get the formdata stored in the database for existing slice.
params: slice_id: integer
Expand Down
56 changes: 31 additions & 25 deletions superset/views/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
import logging
import traceback
from datetime import datetime
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, cast, Dict, List, Optional, TYPE_CHECKING, Union

import dataclasses
import simplejson as json
import yaml
from flask import abort, flash, g, get_flashed_messages, redirect, Response, session
from flask_appbuilder import BaseView, ModelView
from flask_appbuilder import BaseView, Model, ModelView
from flask_appbuilder.actions import action
from flask_appbuilder.forms import DynamicForm
from flask_appbuilder.models.sqla.filters import BaseFilter
Expand All @@ -33,7 +33,9 @@
from flask_babel import get_locale, gettext as __, lazy_gettext as _
from flask_wtf.form import FlaskForm
from sqlalchemy import or_
from sqlalchemy.orm import Query
from werkzeug.exceptions import HTTPException
from wtforms import Form
from wtforms.fields.core import Field, UnboundField

from superset import (
Expand All @@ -47,6 +49,7 @@
from superset.connectors.sqla import models
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetException, SupersetSecurityException
from superset.models.helpers import ImportMixin
from superset.translations.utils import get_language_pack
from superset.typing import FlaskResponse
from superset.utils import core as utils
Expand Down Expand Up @@ -93,7 +96,7 @@ def json_error_response(
status: int = 500,
payload: Optional[Dict[str, Any]] = None,
link: Optional[str] = None,
) -> Response:
) -> FlaskResponse:
if not payload:
payload = {"error": "{}".format(msg)}
if link:
Expand All @@ -110,7 +113,7 @@ def json_errors_response(
errors: List[SupersetError],
status: int = 500,
payload: Optional[Dict[str, Any]] = None,
) -> Response:
) -> FlaskResponse:
if not payload:
payload = {}

Expand All @@ -122,11 +125,11 @@ def json_errors_response(
)


def json_success(json_msg: str, status: int = 200) -> Response:
def json_success(json_msg: str, status: int = 200) -> FlaskResponse:
return Response(json_msg, status=status, mimetype="application/json")


def data_payload_response(payload_json: str, has_error: bool = False) -> Response:
def data_payload_response(payload_json: str, has_error: bool = False) -> FlaskResponse:
status = 400 if has_error else 200
return json_success(payload_json, status=status)

Expand All @@ -140,13 +143,13 @@ def generate_download_headers(
return headers


def api(f):
def api(f: Callable) -> Callable:
"""
A decorator to label an endpoint as an API. Catches uncaught exceptions and
return the response in the JSON format
"""

def wraps(self, *args, **kwargs):
def wraps(self: "BaseSupersetView", *args: Any, **kwargs: Any) -> FlaskResponse:
try:
return f(self, *args, **kwargs)
except Exception as ex: # pylint: disable=broad-except
Expand All @@ -156,14 +159,16 @@ def wraps(self, *args, **kwargs):
return functools.update_wrapper(wraps, f)


def handle_api_exception(f):
def handle_api_exception(
f: Callable[..., FlaskResponse]
) -> Callable[..., FlaskResponse]:
"""
A decorator to catch superset exceptions. Use it after the @api decorator above
so superset exception handler is triggered before the handler for generic
exceptions.
"""

def wraps(self, *args, **kwargs):
def wraps(self: "BaseSupersetView", *args: Any, **kwargs: Any) -> FlaskResponse:
try:
return f(self, *args, **kwargs)
except SupersetSecurityException as ex:
Expand All @@ -179,7 +184,7 @@ def wraps(self, *args, **kwargs):
except HTTPException as ex:
logger.exception(ex)
return json_error_response(
utils.error_msg_from_exception(ex), status=ex.code
utils.error_msg_from_exception(ex), status=cast(int, ex.code)
)
except Exception as ex: # pylint: disable=broad-except
logger.exception(ex)
Expand Down Expand Up @@ -233,15 +238,17 @@ def get_user_roles() -> List[Role]:

class BaseSupersetView(BaseView):
@staticmethod
def json_response(obj, status=200) -> Response: # pylint: disable=no-self-use
def json_response(
obj: Any, status: int = 200
) -> FlaskResponse: # pylint: disable=no-self-use
return Response(
json.dumps(obj, default=utils.json_int_dttm_ser, ignore_nan=True),
status=status,
mimetype="application/json",
)


def menu_data():
def menu_data() -> Dict[str, Any]:
menu = appbuilder.menu.get_data()
root_path = "#"
logo_target_path = ""
Expand Down Expand Up @@ -290,7 +297,7 @@ def menu_data():
}


def common_bootstrap_payload():
def common_bootstrap_payload() -> Dict[str, Any]:
"""Common data always sent to the client"""
messages = get_flashed_messages(with_categories=True)
locale = str(get_locale())
Expand Down Expand Up @@ -335,7 +342,7 @@ class ListWidgetWithCheckboxes(ListWidget): # pylint: disable=too-few-public-me
template = "superset/fab_overrides/list_with_checkboxes.html"


def validate_json(_form, field):
def validate_json(form: Form, field: Field) -> None: # pylint: disable=unused-argument
try:
json.loads(field.data)
except Exception as ex:
Expand All @@ -352,24 +359,23 @@ class YamlExportMixin: # pylint: disable=too-few-public-methods
yaml_dict_key: Optional[str] = None

@action("yaml_export", __("Export to YAML"), __("Export to YAML?"), "fa-download")
def yaml_export(self, items):
def yaml_export(
self, items: Union[ImportMixin, List[ImportMixin]]
) -> FlaskResponse:
if not isinstance(items, list):
items = [items]

data = [t.export_to_dict() for t in items]
if self.yaml_dict_key:
data = {self.yaml_dict_key: data}

return Response(
yaml.safe_dump(data),
yaml.safe_dump({self.yaml_dict_key: data} if self.yaml_dict_key else data),
headers=generate_download_headers("yaml"),
mimetype="application/text",
)


class DeleteMixin: # pylint: disable=too-few-public-methods
def _delete(
self: Union[BaseView, "DeleteMixin", "DruidClusterModelView"], primary_key: int,
) -> None:
def _delete(self: BaseView, primary_key: int,) -> None:
"""
Delete function logic, override to implement diferent logic
deletes the record with primary_key = primary_key
Expand Down Expand Up @@ -411,7 +417,7 @@ def _delete(
@action(
"muldelete", __("Delete"), __("Delete all Really?"), "fa-trash", single=False
)
def muldelete(self, items):
def muldelete(self: BaseView, items: List[Model]) -> FlaskResponse:
if not items:
abort(404)
for item in items:
Expand All @@ -426,7 +432,7 @@ def muldelete(self, items):


class DatasourceFilter(BaseFilter): # pylint: disable=too-few-public-methods
def apply(self, query, value):
def apply(self, query: Query, value: Any) -> Query:
if security_manager.all_datasource_access():
return query
datasource_perms = security_manager.user_view_menu_names("datasource_access")
Expand Down Expand Up @@ -497,7 +503,7 @@ def check_ownership(obj: Any, raise_if_false: bool = True) -> bool:


def bind_field(
_, form: DynamicForm, unbound_field: UnboundField, options: Dict[Any, Any]
_: Any, form: DynamicForm, unbound_field: UnboundField, options: Dict[Any, Any]
) -> Field:
"""
Customize how fields are bound by stripping all whitespace.
Expand Down
Loading

0 comments on commit 63e0188

Please sign in to comment.