
Commit 5b690f9

chore: refactor, add typing and fix uncovered errors (apache#8900)
* Add type annotations and fix inconsistencies
* Address review comments
* Remove incorrect typing of jsonable obj
1 parent 191aca1 commit 5b690f9

22 files changed (+137 −109 lines)

superset/app.py (+1 −1)

@@ -221,7 +221,7 @@ def configure_middlewares(self):

         if self.config["ENABLE_CHUNK_ENCODING"]:

-            class ChunkedEncodingFix(object):  # pylint: disable=too-few-public-methods
+            class ChunkedEncodingFix:  # pylint: disable=too-few-public-methods
                 def __init__(self, app):
                     self.app = app
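The `(object)` base dropped here (and in several files below) is a Python 2 holdover: in Python 3 every class is a new-style class that inherits from `object` implicitly. A minimal illustration of the equivalence:

    # Python 3: both declarations produce identical classes and MROs;
    # the explicit base adds nothing.
    class Explicit(object):
        pass

    class Implicit:
        pass

    assert Explicit.__mro__ == (Explicit, object)
    assert Implicit.__mro__ == (Implicit, object)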

superset/common/query_context.py (+14 −17)

@@ -17,7 +17,7 @@
 import logging
 import pickle as pkl
 from datetime import datetime, timedelta
-from typing import Any, Dict, List, Optional
+from typing import Any, ClassVar, Dict, List, Optional

 import numpy as np
 import pandas as pd
@@ -41,8 +41,8 @@ class QueryContext:
     to retrieve the data payload for a given viz.
     """

-    cache_type: str = "df"
-    enforce_numerical_metrics: bool = True
+    cache_type: ClassVar[str] = "df"
+    enforce_numerical_metrics: ClassVar[bool] = True

     datasource: BaseDatasource
     queries: List[QueryObject]
@@ -53,20 +53,16 @@ class QueryContext:
     # a vanilla python type https://github.com/python/mypy/issues/5288
     def __init__(
         self,
-        datasource: Dict,
-        queries: List[Dict],
+        datasource: Dict[str, Any],
+        queries: List[Dict[str, Any]],
         force: bool = False,
         custom_cache_timeout: Optional[int] = None,
     ) -> None:
-        self.datasource = ConnectorRegistry.get_datasource(  # type: ignore
-            datasource.get("type"),  # type: ignore
-            int(datasource.get("id")),  # type: ignore
-            db.session,
+        self.datasource = ConnectorRegistry.get_datasource(
+            str(datasource["type"]), int(datasource["id"]), db.session
         )
-        self.queries = list(map(lambda query_obj: QueryObject(**query_obj), queries))
-
+        self.queries = [QueryObject(**query_obj) for query_obj in queries]
         self.force = force
-
         self.custom_cache_timeout = custom_cache_timeout

     def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]:
@@ -78,7 +74,7 @@ def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]:

         timestamp_format = None
         if self.datasource.type == "table":
-            dttm_col = self.datasource.get_col(query_object.granularity)
+            dttm_col = self.datasource.get_column(query_object.granularity)
             if dttm_col:
                 timestamp_format = dttm_col.python_date_format

@@ -115,17 +111,18 @@ def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]:
             "df": df,
         }

+    @staticmethod
     def df_metrics_to_num(  # pylint: disable=invalid-name,no-self-use
-        self, df: pd.DataFrame, query_object: QueryObject
+        df: pd.DataFrame, query_object: QueryObject
     ) -> None:
         """Converting metrics to numeric when pandas.read_sql cannot"""
-        metrics = [metric for metric in query_object.metrics]
         for col, dtype in df.dtypes.items():
-            if dtype.type == np.object_ and col in metrics:
+            if dtype.type == np.object_ and col in query_object.metrics:
                 df[col] = pd.to_numeric(df[col], errors="coerce")

+    @staticmethod
     def get_data(  # pylint: disable=invalid-name,no-self-use
-        self, df: pd.DataFrame
+        df: pd.DataFrame
     ) -> List[Dict]:
         return df.to_dict(orient="records")
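The switch to `ClassVar` tells mypy that `cache_type` and `enforce_numerical_metrics` are class-level constants rather than instance attributes. A minimal sketch (not the actual Superset class) of what the annotation buys:

    from typing import ClassVar

    class Example:
        # ClassVar: belongs to the class itself; mypy rejects assignment
        # through an instance, e.g. `Example().cache_type = "json"`.
        cache_type: ClassVar[str] = "df"
        # Plain annotation: an instance attribute, expected to be set in __init__.
        force: bool

        def __init__(self, force: bool = False) -> None:
            self.force = force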

superset/config.py (+1 −1)

@@ -449,7 +449,7 @@ def _try_json_readsha(filepath, length):  # pylint: disable=unused-argument
 # http://docs.celeryproject.org/en/latest/getting-started/brokers/index.html


-class CeleryConfig(object):  # pylint: disable=too-few-public-methods
+class CeleryConfig:  # pylint: disable=too-few-public-methods
     BROKER_URL = "sqla+sqlite:///celerydb.sqlite"
     CELERY_IMPORTS = ("superset.sql_lab", "superset.tasks")
     CELERY_RESULT_BACKEND = "db+sqlite:///celery_results.sqlite"
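For reference, deployments typically shadow this class in their own `superset_config.py` and point `CELERY_CONFIG` at it; a sketch of that pattern, with an example Redis broker URL rather than a recommended value:

    class CeleryConfig:  # pylint: disable=too-few-public-methods
        BROKER_URL = "redis://localhost:6379/0"
        CELERY_IMPORTS = ("superset.sql_lab", "superset.tasks")
        CELERY_RESULT_BACKEND = "redis://localhost:6379/0"

    CELERY_CONFIG = CeleryConfig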

superset/connectors/base/models.py (+12 −6)

@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import json
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Dict, Hashable, List, Optional, Type

 from flask_appbuilder.security.sqla.models import User
 from sqlalchemy import and_, Boolean, Column, Integer, String, Text
@@ -44,7 +44,7 @@ class BaseDatasource(
     baselink: Optional[str] = None  # url portion pointing to ModelView endpoint
     column_class: Optional[Type] = None  # link to derivative of BaseColumn
     metric_class: Optional[Type] = None  # link to derivative of BaseMetric
-    owner_class = None
+    owner_class: Optional[User] = None

     # Used to do code highlighting when displaying the query in the UI
     query_language: Optional[str] = None
@@ -342,11 +342,14 @@ def update_from_object(self, obj) -> None:
             obj.get("columns"), self.columns, self.column_class, "column_name"
         )

-    def get_extra_cache_keys(  # pylint: disable=unused-argument,no-self-use
-        self, query_obj: Dict
-    ) -> List[Any]:
+    def get_extra_cache_keys(  # pylint: disable=no-self-use
+        self, query_obj: Dict[str, Any]  # pylint: disable=unused-argument
+    ) -> List[Hashable]:
         """ If a datasource needs to provide additional keys for calculation of
         cache keys, those can be provided via this method
+
+        :param query_obj: The dict representation of a query object
+        :return: list of keys
         """
         return []
@@ -403,6 +406,10 @@ def is_string(self) -> bool:
     def expression(self):
         raise NotImplementedError()

+    @property
+    def python_date_format(self):
+        raise NotImplementedError()
+
     @property
     def data(self) -> Dict[str, Any]:
         attrs = (
@@ -415,7 +422,6 @@ def data(self) -> Dict[str, Any]:
             "groupby",
             "is_dttm",
             "type",
-            "python_date_format",
         )
         return {s: getattr(self, s) for s in attrs if hasattr(self, s)}
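Widening the return type from `List[Any]` to `List[Hashable]` documents that anything usable as a cache key component (strings, numbers, tuples) may be returned. A hypothetical subclass override, for illustration only:

    from typing import Any, Dict, Hashable, List

    class ExampleDatasource(BaseDatasource):  # hypothetical subclass
        def get_extra_cache_keys(self, query_obj: Dict[str, Any]) -> List[Hashable]:
            # Tuples and strings are hashable, so both are valid key parts;
            # a mutable value such as a list would fail the type check.
            return [("schema", "public"), "latest_partition=2019-11-20"]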

superset/connectors/connector_registry.py (+3 −3)

@@ -26,7 +26,7 @@
 from superset.connectors.base.models import BaseDatasource


-class ConnectorRegistry(object):
+class ConnectorRegistry:
     """ Central Registry for all available datasource engines"""

     sources: Dict[str, Type["BaseDatasource"]] = {}
@@ -43,11 +43,11 @@ def register_sources(cls, datasource_config: OrderedDict) -> None:
     @classmethod
     def get_datasource(
         cls, datasource_type: str, datasource_id: int, session: Session
-    ) -> Optional["BaseDatasource"]:
+    ) -> "BaseDatasource":
         return (
             session.query(cls.sources[datasource_type])
             .filter_by(id=datasource_id)
-            .first()
+            .one()
         )

     @classmethod
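Replacing `.first()` with `.one()` is what lets the `Optional` be dropped from the return type: `.first()` returns `None` when no row matches, while `.one()` raises unless exactly one row matches, so callers never receive `None`. A sketch of the changed failure mode, using hypothetical session and model names:

    from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound

    def load_one(session, model, pk: int):
        try:
            # .one() asserts exactly one matching row; no Optional needed.
            return session.query(model).filter_by(id=pk).one()
        except NoResultFound:
            raise ValueError(f"no row with id {pk}")
        except MultipleResultsFound:
            raise ValueError(f"id {pk} is ambiguous")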

superset/connectors/sqla/models.py (+35 −24)

@@ -19,7 +19,7 @@
 import re
 from collections import OrderedDict
 from datetime import datetime
-from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
+from typing import Any, Dict, Hashable, List, NamedTuple, Optional, Tuple, Union

 import pandas as pd
 import sqlalchemy as sa
@@ -84,7 +84,7 @@ class AnnotationDatasource(BaseDatasource):

     cache_timeout = 0

-    def query(self, query_obj: Dict) -> QueryResult:
+    def query(self, query_obj: Dict[str, Any]) -> QueryResult:
         df = None
         error_message = None
         qry = db.session.query(Annotation)
@@ -537,16 +537,9 @@ def select_star(self) -> str:
             latest_partition=False,
         )

-    def get_col(self, col_name: str) -> Optional[Column]:
-        columns = self.columns
-        for col in columns:
-            if col_name == col.column_name:
-                return col
-        return None
-
     @property
     def data(self) -> Dict:
-        d = super(SqlaTable, self).data
+        d = super().data
         if self.type == "table":
             grains = self.database.grains() or []
             if grains:
@@ -598,7 +591,7 @@ def mutate_query_from_config(self, sql: str) -> str:
     def get_template_processor(self, **kwargs):
         return get_template_processor(table=self, database=self.database, **kwargs)

-    def get_query_str_extended(self, query_obj: Dict) -> QueryStringExtended:
+    def get_query_str_extended(self, query_obj: Dict[str, Any]) -> QueryStringExtended:
         sqlaq = self.get_sqla_query(**query_obj)
         sql = self.database.compile_sqla_query(sqlaq.sqla_query)
         logging.info(sql)
@@ -608,7 +601,7 @@ def get_query_str_extended(self, query_obj: Dict) -> QueryStringExtended:
             labels_expected=sqlaq.labels_expected, sql=sql, prequeries=sqlaq.prequeries
         )

-    def get_query_str(self, query_obj: Dict) -> str:
+    def get_query_str(self, query_obj: Dict[str, Any]) -> str:
         query_str_ext = self.get_query_str_extended(query_obj)
         all_queries = query_str_ext.prequeries + [query_str_ext.sql]
         return ";\n\n".join(all_queries) + ";"
@@ -976,14 +969,23 @@ def _get_top_groups(

         return or_(*groups)

-    def query(self, query_obj: Dict) -> QueryResult:
+    def query(self, query_obj: Dict[str, Any]) -> QueryResult:
         qry_start_dttm = datetime.now()
         query_str_ext = self.get_query_str_extended(query_obj)
         sql = query_str_ext.sql
         status = utils.QueryStatus.SUCCESS
         error_message = None

-        def mutator(df):
+        def mutator(df: pd.DataFrame) -> None:
+            """
+            Some engines change the case or generate bespoke column names, either by
+            default or due to lack of support for aliasing. This function ensures that
+            the column names in the DataFrame correspond to what is expected by
+            the viz components.
+
+            :param df: Original DataFrame returned by the engine
+            """
+
             labels_expected = query_str_ext.labels_expected
             if df is not None and not df.empty:
                 if len(df.columns) != len(labels_expected):
@@ -993,7 +995,6 @@ def mutator(df):
                     )
                 else:
                     df.columns = labels_expected
-            return df

         try:
             df = self.database.get_df(sql, self.schema, mutator)
@@ -1135,13 +1136,16 @@ def query_datasources_by_name(
     def default_query(qry) -> Query:
         return qry.filter_by(is_sqllab_view=False)

-    def has_extra_cache_keys(self, query_obj: Dict) -> bool:
+    def has_calls_to_cache_key_wrapper(self, query_obj: Dict[str, Any]) -> bool:
         """
-        Detects the presence of calls to cache_key_wrapper in items in query_obj that can
-        be templated.
+        Detects the presence of calls to `cache_key_wrapper` in items in query_obj that
+        can be templated. If any are present, the query must be evaluated to extract
+        additional keys for the cache key. This method is needed to avoid executing
+        the template code unnecessarily, as it may contain expensive calls, e.g. to
+        extract the latest partition of a database.

         :param query_obj: query object to analyze
-        :return: True if at least one item calls cache_key_wrapper, otherwise False
+        :return: True if at least one item calls `cache_key_wrapper`, otherwise False
         """
         regex = re.compile(r"\{\{.*cache_key_wrapper\(.*\).*\}\}")
         templatable_statements: List[str] = []
@@ -1159,12 +1163,19 @@ def has_extra_cache_keys(self, query_obj: Dict) -> bool:
             return True
         return False

-    def get_extra_cache_keys(self, query_obj: Dict) -> List[Any]:
-        if self.has_extra_cache_keys(query_obj):
+    def get_extra_cache_keys(self, query_obj: Dict[str, Any]) -> List[Hashable]:
+        """
+        The cache key of a SqlaTable needs to consider any keys added by the parent
+        class and any keys added via `cache_key_wrapper`.
+
+        :param query_obj: query object to analyze
+        :return: list of cache keys
+        """
+        extra_cache_keys = super().get_extra_cache_keys(query_obj)
+        if self.has_calls_to_cache_key_wrapper(query_obj):
             sqla_query = self.get_sqla_query(**query_obj)
-            extra_cache_keys = sqla_query.extra_cache_keys
-            return extra_cache_keys
-        return []
+            extra_cache_keys += sqla_query.extra_cache_keys
+        return extra_cache_keys


 sa.event.listen(SqlaTable, "after_insert", security_manager.set_perm)
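The regex in `has_calls_to_cache_key_wrapper` only checks whether a templatable statement references `cache_key_wrapper` inside Jinja `{{ ... }}` delimiters; the potentially expensive template evaluation happens only on a match. A quick standalone check of the pattern's behavior:

    import re

    # The exact pattern from has_calls_to_cache_key_wrapper above.
    regex = re.compile(r"\{\{.*cache_key_wrapper\(.*\).*\}\}")

    # Matches: a templated call inside Jinja delimiters.
    assert regex.search("SELECT * FROM t WHERE ds = '{{ cache_key_wrapper(ds) }}'")
    # No match: plain SQL, so template evaluation is skipped.
    assert not regex.search("SELECT * FROM t WHERE ds = '2019-11-20'")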

superset/dataframe.py (+1 −1)

@@ -68,7 +68,7 @@ def is_numeric(dtype):
     return np.issubdtype(dtype, np.number)


-class SupersetDataFrame(object):
+class SupersetDataFrame:
     # Mapping numpy dtype.char to generic database types
     type_map = {
         "b": "BOOL",  # boolean

superset/db_engine_specs/base.py (+1 −1)

@@ -104,7 +104,7 @@ def compile_timegrain_expression(
     return element.name.replace("{col}", compiler.process(element.col, **kw))


-class LimitMethod(object):  # pylint: disable=too-few-public-methods
+class LimitMethod:  # pylint: disable=too-few-public-methods
     """Enum the ways that limits can be applied"""

     FETCH_MANY = "fetch_many"

superset/migrations/versions/258b5280a45e_form_strip_leading_and_trailing_whitespace.py (+3 −3)

@@ -33,20 +33,20 @@
 Base = declarative_base()


-class BaseColumnMixin(object):
+class BaseColumnMixin:
     id = Column(Integer, primary_key=True)
     column_name = Column(String(255))
     description = Column(Text)
     type = Column(String(32))
     verbose_name = Column(String(1024))


-class BaseDatasourceMixin(object):
+class BaseDatasourceMixin:
     id = Column(Integer, primary_key=True)
     description = Column(Text)


-class BaseMetricMixin(object):
+class BaseMetricMixin:
     id = Column(Integer, primary_key=True)
     d3format = Column(String(128))
     description = Column(Text)

superset/migrations/versions/c617da68de7d_form_nullable.py (+3 −3)

@@ -36,20 +36,20 @@
 Base = declarative_base()


-class BaseColumnMixin(object):
+class BaseColumnMixin:
     id = Column(Integer, primary_key=True)
     column_name = Column(String(255))
     description = Column(Text)
     type = Column(String(32))
     verbose_name = Column(String(1024))


-class BaseDatasourceMixin(object):
+class BaseDatasourceMixin:
     id = Column(Integer, primary_key=True)
     description = Column(Text)


-class BaseMetricMixin(object):
+class BaseMetricMixin:
     id = Column(Integer, primary_key=True)
     d3format = Column(String(128))
     description = Column(Text)

superset/migrations/versions/d94d33dbe938_form_strip.py (+3 −3)

@@ -36,20 +36,20 @@
 Base = declarative_base()


-class BaseColumnMixin(object):
+class BaseColumnMixin:
     id = Column(Integer, primary_key=True)
     column_name = Column(String(255))
     description = Column(Text)
     type = Column(String(32))
     verbose_name = Column(String(1024))


-class BaseDatasourceMixin(object):
+class BaseDatasourceMixin:
     id = Column(Integer, primary_key=True)
     description = Column(Text)


-class BaseMetricMixin(object):
+class BaseMetricMixin:
     id = Column(Integer, primary_key=True)
     d3format = Column(String(128))
     description = Column(Text)
