Skip to content

Commit

Permalink
chore(db_engine_specs): clean up column spec logic and add tests (#22871)
Browse files Browse the repository at this point in the history
  • Loading branch information
villebro authored Jan 31, 2023
1 parent 8466eec commit cd6fc35
Show file tree
Hide file tree
Showing 73 changed files with 1,946 additions and 1,456 deletions.
9 changes: 5 additions & 4 deletions superset/db_engine_specs/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
from typing import Any, Dict, Optional, Pattern, Tuple

from flask_babel import gettext as __
from sqlalchemy import types

from superset.db_engine_specs.base import BaseEngineSpec
from superset.errors import SupersetErrorType
from superset.utils import core as utils

SYNTAX_ERROR_REGEX = re.compile(
": mismatched input '(?P<syntax_error>.*?)'. Expecting: "
Expand Down Expand Up @@ -66,10 +66,11 @@ class AthenaEngineSpec(BaseEngineSpec):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)

if isinstance(sqla_type, types.Date):
return f"DATE '{dttm.date().isoformat()}'"
if tt == utils.TemporalType.TIMESTAMP:
if isinstance(sqla_type, types.TIMESTAMP):
datetime_formatted = dttm.isoformat(sep=" ", timespec="milliseconds")
return f"""TIMESTAMP '{datetime_formatted}'"""
return None
Expand Down
52 changes: 37 additions & 15 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods

_date_trunc_functions: Dict[str, str] = {}
_time_grain_expressions: Dict[Optional[str], str] = {}
column_type_mappings: Tuple[ColumnTypeMapping, ...] = (
_default_column_type_mappings: Tuple[ColumnTypeMapping, ...] = (
(
re.compile(r"^string", re.IGNORECASE),
types.String(),
Expand Down Expand Up @@ -314,6 +314,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
GenericDataType.BOOLEAN,
),
)
# engine-specific type mappings to check prior to the defaults
column_type_mappings: Tuple[ColumnTypeMapping, ...] = ()

# Does database support join-free timeslot grouping
time_groupby_inline = False
Expand Down Expand Up @@ -1389,24 +1391,25 @@ def make_label_compatible(cls, label: str) -> Union[str, quoted_name]:
return label_mutated

@classmethod
def get_sqla_column_type(
def get_column_types(
cls,
column_type: Optional[str],
column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings,
) -> Optional[Tuple[TypeEngine, GenericDataType]]:
"""
Return a sqlalchemy native column type that corresponds to the column type
defined in the data source (return None to use default type inferred by
SQLAlchemy). Override `column_type_mappings` for specific needs
Return a sqlalchemy native column type and generic data type that corresponds
to the column type defined in the data source (return None to use default type
inferred by SQLAlchemy). Override `column_type_mappings` for specific needs
(see MSSQL for example of NCHAR/NVARCHAR handling).
:param column_type: Column type returned by inspector
:param column_type_mappings: Maps from string to SqlAlchemy TypeEngine
:return: SqlAlchemy column type
:return: SQLAlchemy and generic Superset column types
"""
if not column_type:
return None
for regex, sqla_type, generic_type in column_type_mappings:

for regex, sqla_type, generic_type in (
cls.column_type_mappings + cls._default_column_type_mappings
):
match = regex.match(column_type)
if not match:
continue
Expand Down Expand Up @@ -1569,19 +1572,16 @@ def get_column_spec( # pylint: disable=unused-argument
native_type: Optional[str],
db_extra: Optional[Dict[str, Any]] = None,
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings,
) -> Optional[ColumnSpec]:
"""
Converts native database type to sqlalchemy column type.
Get generic type related specs regarding a native column type.
:param native_type: Native database type
:param db_extra: The database extra object
:param source: Type coming from the database table or cursor description
:param column_type_mappings: Maps from string to SqlAlchemy TypeEngine
:return: ColumnSpec object
"""
col_types = cls.get_sqla_column_type(
native_type, column_type_mappings=column_type_mappings
)
col_types = cls.get_column_types(native_type)
if col_types:
column_type, generic_type = col_types
is_dttm = generic_type == GenericDataType.TEMPORAL
Expand All @@ -1590,6 +1590,28 @@ def get_column_spec( # pylint: disable=unused-argument
)
return None

@classmethod
def get_sqla_column_type(
    cls,
    native_type: Optional[str],
    db_extra: Optional[Dict[str, Any]] = None,
    source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
) -> Optional[TypeEngine]:
    """
    Converts native database type to sqlalchemy column type.

    Delegates to `get_column_spec` and returns only the `sqla_type` attribute
    of the resulting spec.

    :param native_type: Native database type
    :param db_extra: The database extra object
    :param source: Type coming from the database table or cursor description
    :return: SQLAlchemy column type, or None when `get_column_spec` yields no
             spec for the native type
    """
    column_spec = cls.get_column_spec(
        native_type=native_type,
        db_extra=db_extra,
        source=source,
    )
    # No spec (falsy result) means the native type could not be resolved.
    return column_spec.sqla_type if column_spec else None

# pylint: disable=unused-argument
@classmethod
def prepare_cancel_query(cls, query: Query, session: Session) -> None:
Expand Down
14 changes: 7 additions & 7 deletions superset/db_engine_specs/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from flask_babel import gettext as __
from marshmallow import fields, Schema
from marshmallow.exceptions import ValidationError
from sqlalchemy import column
from sqlalchemy import column, types
from sqlalchemy.engine.base import Engine
from sqlalchemy.sql import sqltypes
from typing_extensions import TypedDict
Expand Down Expand Up @@ -201,15 +201,15 @@ class BigQueryEngineSpec(BaseEngineSpec):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"CAST('{dttm.date().isoformat()}' AS DATE)"
if tt == utils.TemporalType.DATETIME:
if isinstance(sqla_type, types.TIMESTAMP):
return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS TIMESTAMP)"""
if isinstance(sqla_type, types.DateTime):
return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS DATETIME)"""
if tt == utils.TemporalType.TIME:
if isinstance(sqla_type, types.Time):
return f"""CAST('{dttm.strftime("%H:%M:%S.%f")}' AS TIME)"""
if tt == utils.TemporalType.TIMESTAMP:
return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS TIMESTAMP)"""
return None

@classmethod
Expand Down
9 changes: 5 additions & 4 deletions superset/db_engine_specs/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
from datetime import datetime
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING

from sqlalchemy import types
from urllib3.exceptions import NewConnectionError

from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError
from superset.extensions import cache_manager
from superset.utils import core as utils

if TYPE_CHECKING:
# prevent circular imports
Expand Down Expand Up @@ -77,10 +77,11 @@ def get_dbapi_mapped_exception(cls, exception: Exception) -> Exception:
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)

if isinstance(sqla_type, types.Date):
return f"toDate('{dttm.date().isoformat()}')"
if tt == utils.TemporalType.DATETIME:
if isinstance(sqla_type, types.DateTime):
return f"""toDateTime('{dttm.isoformat(sep=" ", timespec="seconds")}')"""
return None

Expand Down
12 changes: 8 additions & 4 deletions superset/db_engine_specs/crate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from datetime import datetime
from typing import Any, Dict, Optional, TYPE_CHECKING

from sqlalchemy import types

from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils import core as utils

if TYPE_CHECKING:
from superset.connectors.sqla.models import TableColumn
Expand Down Expand Up @@ -53,12 +56,13 @@ def epoch_ms_to_dttm(cls) -> str:
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.TIMESTAMP:
sqla_type = cls.get_sqla_column_type(target_type)

if isinstance(sqla_type, types.TIMESTAMP):
return f"{dttm.timestamp() * 1000}"
return None

@classmethod
def alter_new_orm_column(cls, orm_col: "TableColumn") -> None:
def alter_new_orm_column(cls, orm_col: TableColumn) -> None:
if orm_col.type == "TIMESTAMP":
orm_col.python_date_format = "epoch_ms"
10 changes: 6 additions & 4 deletions superset/db_engine_specs/dremio.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
from datetime import datetime
from typing import Any, Dict, Optional

from sqlalchemy import types

from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils import core as utils


class DremioEngineSpec(BaseEngineSpec):
Expand Down Expand Up @@ -46,10 +47,11 @@ def epoch_to_dttm(cls) -> str:
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)

if isinstance(sqla_type, types.Date):
return f"TO_DATE('{dttm.date().isoformat()}', 'YYYY-MM-DD')"
if tt == utils.TemporalType.TIMESTAMP:
if isinstance(sqla_type, types.TIMESTAMP):
dttm_formatted = dttm.isoformat(sep=" ", timespec="milliseconds")
return f"""TO_TIMESTAMP('{dttm_formatted}', 'YYYY-MM-DD HH24:MI:SS.FFF')"""
return None
9 changes: 5 additions & 4 deletions superset/db_engine_specs/drill.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
from typing import Any, Dict, Optional
from urllib import parse

from sqlalchemy import types
from sqlalchemy.engine.url import URL

from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.exceptions import SupersetDBAPIProgrammingError
from superset.utils import core as utils


class DrillEngineSpec(BaseEngineSpec):
Expand Down Expand Up @@ -59,10 +59,11 @@ def epoch_ms_to_dttm(cls) -> str:
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)

if isinstance(sqla_type, types.Date):
return f"TO_DATE('{dttm.date().isoformat()}', 'yyyy-MM-dd')"
if tt == utils.TemporalType.TIMESTAMP:
if isinstance(sqla_type, types.TIMESTAMP):
datetime_formatted = dttm.isoformat(sep=" ", timespec="seconds")
return f"""TO_TIMESTAMP('{datetime_formatted}', 'yyyy-MM-dd HH:mm:ss')"""
return None
Expand Down
15 changes: 10 additions & 5 deletions superset/db_engine_specs/druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import json
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING

from sqlalchemy import types
from sqlalchemy.engine.reflection import Inspector

from superset import is_feature_enabled
Expand Down Expand Up @@ -70,12 +74,12 @@ class DruidEngineSpec(BaseEngineSpec):
}

@classmethod
def alter_new_orm_column(cls, orm_col: "TableColumn") -> None:
def alter_new_orm_column(cls, orm_col: TableColumn) -> None:
if orm_col.column_name == "__time":
orm_col.is_dttm = True

@staticmethod
def get_extra_params(database: "Database") -> Dict[str, Any]:
def get_extra_params(database: Database) -> Dict[str, Any]:
"""
For Druid, the path to a SSL certificate is placed in `connect_args`.
Expand All @@ -102,10 +106,11 @@ def get_extra_params(database: "Database") -> Dict[str, Any]:
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)

if isinstance(sqla_type, types.Date):
return f"CAST(TIME_PARSE('{dttm.date().isoformat()}') AS DATE)"
if tt in (utils.TemporalType.DATETIME, utils.TemporalType.TIMESTAMP):
if isinstance(sqla_type, (types.DateTime, types.TIMESTAMP)):
return f"""TIME_PARSE('{dttm.isoformat(timespec="seconds")}')"""
return None

Expand Down
7 changes: 4 additions & 3 deletions superset/db_engine_specs/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
from typing import Any, Dict, Optional, Pattern, Set, Tuple, TYPE_CHECKING

from flask_babel import gettext as __
from sqlalchemy import types
from sqlalchemy.engine.reflection import Inspector

from superset.db_engine_specs.base import BaseEngineSpec
from superset.errors import SupersetErrorType
from superset.utils import core as utils

if TYPE_CHECKING:
# prevent circular imports
Expand Down Expand Up @@ -67,8 +67,9 @@ def epoch_to_dttm(cls) -> str:
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt in (utils.TemporalType.TEXT, utils.TemporalType.DATETIME):
sqla_type = cls.get_sqla_column_type(target_type)

if isinstance(sqla_type, (types.String, types.DateTime)):
return f"""'{dttm.isoformat(sep=" ", timespec="microseconds")}'"""
return None

Expand Down
9 changes: 6 additions & 3 deletions superset/db_engine_specs/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
from datetime import datetime
from typing import Any, Dict, Optional

from sqlalchemy import types

from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils import core as utils


class DynamoDBEngineSpec(BaseEngineSpec):
Expand Down Expand Up @@ -56,7 +57,9 @@ def epoch_to_dttm(cls) -> str:
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt in (utils.TemporalType.TEXT, utils.TemporalType.DATETIME):
sqla_type = cls.get_sqla_column_type(target_type)

if isinstance(sqla_type, (types.String, types.DateTime)):
return f"""'{dttm.isoformat(sep=" ", timespec="seconds")}'"""

return None
Loading

0 comments on commit cd6fc35

Please sign in to comment.