19
19
import re
20
20
from collections import OrderedDict
21
21
from datetime import datetime
22
- from typing import Any , Dict , List , NamedTuple , Optional , Tuple , Union
22
+ from typing import Any , Dict , Hashable , List , NamedTuple , Optional , Tuple , Union
23
23
24
24
import pandas as pd
25
25
import sqlalchemy as sa
@@ -84,7 +84,7 @@ class AnnotationDatasource(BaseDatasource):
84
84
85
85
cache_timeout = 0
86
86
87
- def query (self , query_obj : Dict ) -> QueryResult :
87
+ def query (self , query_obj : Dict [ str , Any ] ) -> QueryResult :
88
88
df = None
89
89
error_message = None
90
90
qry = db .session .query (Annotation )
@@ -537,16 +537,9 @@ def select_star(self) -> str:
537
537
latest_partition = False ,
538
538
)
539
539
540
- def get_col (self , col_name : str ) -> Optional [Column ]:
541
- columns = self .columns
542
- for col in columns :
543
- if col_name == col .column_name :
544
- return col
545
- return None
546
-
547
540
@property
548
541
def data (self ) -> Dict :
549
- d = super (SqlaTable , self ).data
542
+ d = super ().data
550
543
if self .type == "table" :
551
544
grains = self .database .grains () or []
552
545
if grains :
@@ -598,7 +591,7 @@ def mutate_query_from_config(self, sql: str) -> str:
598
591
def get_template_processor (self , ** kwargs ):
599
592
return get_template_processor (table = self , database = self .database , ** kwargs )
600
593
601
- def get_query_str_extended (self , query_obj : Dict ) -> QueryStringExtended :
594
+ def get_query_str_extended (self , query_obj : Dict [ str , Any ] ) -> QueryStringExtended :
602
595
sqlaq = self .get_sqla_query (** query_obj )
603
596
sql = self .database .compile_sqla_query (sqlaq .sqla_query )
604
597
logging .info (sql )
@@ -608,7 +601,7 @@ def get_query_str_extended(self, query_obj: Dict) -> QueryStringExtended:
608
601
labels_expected = sqlaq .labels_expected , sql = sql , prequeries = sqlaq .prequeries
609
602
)
610
603
611
- def get_query_str (self , query_obj : Dict ) -> str :
604
+ def get_query_str (self , query_obj : Dict [ str , Any ] ) -> str :
612
605
query_str_ext = self .get_query_str_extended (query_obj )
613
606
all_queries = query_str_ext .prequeries + [query_str_ext .sql ]
614
607
return ";\n \n " .join (all_queries ) + ";"
@@ -976,14 +969,23 @@ def _get_top_groups(
976
969
977
970
return or_ (* groups )
978
971
979
- def query (self , query_obj : Dict ) -> QueryResult :
972
+ def query (self , query_obj : Dict [ str , Any ] ) -> QueryResult :
980
973
qry_start_dttm = datetime .now ()
981
974
query_str_ext = self .get_query_str_extended (query_obj )
982
975
sql = query_str_ext .sql
983
976
status = utils .QueryStatus .SUCCESS
984
977
error_message = None
985
978
986
- def mutator (df ):
979
+ def mutator (df : pd .DataFrame ) -> None :
980
+ """
981
+ Some engines change the case or generate bespoke column names, either by
982
+ default or due to lack of support for aliasing. This function ensures that
983
+ the column names in the DataFrame correspond to what is expected by
984
+ the viz components.
985
+
986
+ :param df: Original DataFrame returned by the engine
987
+ """
988
+
987
989
labels_expected = query_str_ext .labels_expected
988
990
if df is not None and not df .empty :
989
991
if len (df .columns ) != len (labels_expected ):
@@ -993,7 +995,6 @@ def mutator(df):
993
995
)
994
996
else :
995
997
df .columns = labels_expected
996
- return df
997
998
998
999
try :
999
1000
df = self .database .get_df (sql , self .schema , mutator )
@@ -1135,13 +1136,16 @@ def query_datasources_by_name(
1135
1136
def default_query (qry ) -> Query :
1136
1137
return qry .filter_by (is_sqllab_view = False )
1137
1138
1138
- def has_extra_cache_keys (self , query_obj : Dict ) -> bool :
1139
+ def has_calls_to_cache_key_wrapper (self , query_obj : Dict [ str , Any ] ) -> bool :
1139
1140
"""
1140
- Detects the presence of calls to cache_key_wrapper in items in query_obj that can
1141
- be templated.
1141
+ Detects the presence of calls to `cache_key_wrapper` in items in query_obj that
1142
+ can be templated. If any are present, the query must be evaluated to extract
1143
+ additional keys for the cache key. This method is needed to avoid executing
1144
+ the template code unnecessarily, as it may contain expensive calls, e.g. to
1145
+ extract the latest partition of a database.
1142
1146
1143
1147
:param query_obj: query object to analyze
1144
- :return: True if at least one item calls cache_key_wrapper, otherwise False
1148
+ :return: True if at least one item calls `cache_key_wrapper`, otherwise False
1145
1149
"""
1146
1150
regex = re .compile (r"\{\{.*cache_key_wrapper\(.*\).*\}\}" )
1147
1151
templatable_statements : List [str ] = []
@@ -1159,12 +1163,19 @@ def has_extra_cache_keys(self, query_obj: Dict) -> bool:
1159
1163
return True
1160
1164
return False
1161
1165
1162
- def get_extra_cache_keys (self , query_obj : Dict ) -> List [Any ]:
1163
- if self .has_extra_cache_keys (query_obj ):
1166
+ def get_extra_cache_keys (self , query_obj : Dict [str , Any ]) -> List [Hashable ]:
1167
+ """
1168
+ The cache key of a SqlaTable needs to consider any keys added by the parent class
1169
+ and any keys added via `cache_key_wrapper`.
1170
+
1171
+ :param query_obj: query object to analyze
1172
+ :return: The extra cache keys from the parent class plus any keys added via `cache_key_wrapper`
1173
+ """
1174
+ extra_cache_keys = super ().get_extra_cache_keys (query_obj )
1175
+ if self .has_calls_to_cache_key_wrapper (query_obj ):
1164
1176
sqla_query = self .get_sqla_query (** query_obj )
1165
- extra_cache_keys = sqla_query .extra_cache_keys
1166
- return extra_cache_keys
1167
- return []
1177
+ extra_cache_keys += sqla_query .extra_cache_keys
1178
+ return extra_cache_keys
1168
1179
1169
1180
1170
1181
sa .event .listen (SqlaTable , "after_insert" , security_manager .set_perm )
0 commit comments