Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(db_engine_specs): clean up column spec logic and add tests #22871

Merged
merged 20 commits into from
Jan 31, 2023
Merged
9 changes: 5 additions & 4 deletions superset/db_engine_specs/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
from typing import Any, Dict, Optional, Pattern, Tuple

from flask_babel import gettext as __
from sqlalchemy import types

from superset.db_engine_specs.base import BaseEngineSpec
from superset.errors import SupersetErrorType
from superset.utils import core as utils

SYNTAX_ERROR_REGEX = re.compile(
": mismatched input '(?P<syntax_error>.*?)'. Expecting: "
Expand Down Expand Up @@ -66,10 +66,11 @@ class AthenaEngineSpec(BaseEngineSpec):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)

if isinstance(sqla_type, types.Date):
return f"DATE '{dttm.date().isoformat()}'"
if tt == utils.TemporalType.TIMESTAMP:
if isinstance(sqla_type, types.TIMESTAMP):
datetime_formatted = dttm.isoformat(sep=" ", timespec="milliseconds")
return f"""TIMESTAMP '{datetime_formatted}'"""
return None
Expand Down
52 changes: 37 additions & 15 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods

_date_trunc_functions: Dict[str, str] = {}
_time_grain_expressions: Dict[Optional[str], str] = {}
column_type_mappings: Tuple[ColumnTypeMapping, ...] = (
_default_column_type_mappings: Tuple[ColumnTypeMapping, ...] = (
(
re.compile(r"^string", re.IGNORECASE),
types.String(),
Expand Down Expand Up @@ -314,6 +314,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
GenericDataType.BOOLEAN,
),
)
# engine-specific type mappings to check prior to the defaults
column_type_mappings: Tuple[ColumnTypeMapping, ...] = ()

# Does database support join-free timeslot grouping
time_groupby_inline = False
Expand Down Expand Up @@ -1389,24 +1391,25 @@ def make_label_compatible(cls, label: str) -> Union[str, quoted_name]:
return label_mutated

@classmethod
def get_sqla_column_type(
def get_column_types(
cls,
column_type: Optional[str],
column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings,
) -> Optional[Tuple[TypeEngine, GenericDataType]]:
"""
Return a sqlalchemy native column type that corresponds to the column type
defined in the data source (return None to use default type inferred by
SQLAlchemy). Override `column_type_mappings` for specific needs
Return a sqlalchemy native column type and generic data type that correspond
to the column type defined in the data source (return None to use default type
inferred by SQLAlchemy). Override `column_type_mappings` for specific needs
(see MSSQL for example of NCHAR/NVARCHAR handling).

:param column_type: Column type returned by inspector
:param column_type_mappings: Maps from string to SqlAlchemy TypeEngine
:return: SqlAlchemy column type
:return: SQLAlchemy and generic Superset column types
"""
if not column_type:
return None
for regex, sqla_type, generic_type in column_type_mappings:

for regex, sqla_type, generic_type in (
cls.column_type_mappings + cls._default_column_type_mappings
):
match = regex.match(column_type)
if not match:
continue
Expand Down Expand Up @@ -1569,19 +1572,16 @@ def get_column_spec( # pylint: disable=unused-argument
native_type: Optional[str],
db_extra: Optional[Dict[str, Any]] = None,
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings,
) -> Optional[ColumnSpec]:
"""
Converts native database type to sqlalchemy column type.
Get generic type related specs regarding a native column type.

:param native_type: Native database type
:param db_extra: The database extra object
:param source: Type coming from the database table or cursor description
:param column_type_mappings: Maps from string to SqlAlchemy TypeEngine
:return: ColumnSpec object
"""
col_types = cls.get_sqla_column_type(
native_type, column_type_mappings=column_type_mappings
)
col_types = cls.get_column_types(native_type)
if col_types:
column_type, generic_type = col_types
is_dttm = generic_type == GenericDataType.TEMPORAL
Expand All @@ -1590,6 +1590,28 @@ def get_column_spec( # pylint: disable=unused-argument
)
return None

@classmethod
def get_sqla_column_type(
cls,
native_type: Optional[str],
db_extra: Optional[Dict[str, Any]] = None,
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
) -> Optional[TypeEngine]:
"""
Converts native database type to sqlalchemy column type.

:param native_type: Native database type
:param db_extra: The database extra object
:param source: Type coming from the database table or cursor description
:return: SQLAlchemy column type, or None if the native type is not recognized
"""
column_spec = cls.get_column_spec(
native_type=native_type,
db_extra=db_extra,
source=source,
)
return column_spec.sqla_type if column_spec else None

# pylint: disable=unused-argument
@classmethod
def prepare_cancel_query(cls, query: Query, session: Session) -> None:
Expand Down
14 changes: 7 additions & 7 deletions superset/db_engine_specs/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from flask_babel import gettext as __
from marshmallow import fields, Schema
from marshmallow.exceptions import ValidationError
from sqlalchemy import column
from sqlalchemy import column, types
from sqlalchemy.engine.base import Engine
from sqlalchemy.sql import sqltypes
from typing_extensions import TypedDict
Expand Down Expand Up @@ -201,15 +201,15 @@ class BigQueryEngineSpec(BaseEngineSpec):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"CAST('{dttm.date().isoformat()}' AS DATE)"
if tt == utils.TemporalType.DATETIME:
if isinstance(sqla_type, types.TIMESTAMP):
return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS TIMESTAMP)"""
if isinstance(sqla_type, types.DateTime):
return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS DATETIME)"""
if tt == utils.TemporalType.TIME:
if isinstance(sqla_type, types.Time):
return f"""CAST('{dttm.strftime("%H:%M:%S.%f")}' AS TIME)"""
if tt == utils.TemporalType.TIMESTAMP:
return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS TIMESTAMP)"""
return None

@classmethod
Expand Down
9 changes: 5 additions & 4 deletions superset/db_engine_specs/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
from datetime import datetime
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING

from sqlalchemy import types
from urllib3.exceptions import NewConnectionError

from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError
from superset.extensions import cache_manager
from superset.utils import core as utils

if TYPE_CHECKING:
# prevent circular imports
Expand Down Expand Up @@ -77,10 +77,11 @@ def get_dbapi_mapped_exception(cls, exception: Exception) -> Exception:
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)

if isinstance(sqla_type, types.Date):
return f"toDate('{dttm.date().isoformat()}')"
if tt == utils.TemporalType.DATETIME:
if isinstance(sqla_type, types.DateTime):
return f"""toDateTime('{dttm.isoformat(sep=" ", timespec="seconds")}')"""
return None

Expand Down
12 changes: 8 additions & 4 deletions superset/db_engine_specs/crate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from datetime import datetime
from typing import Any, Dict, Optional, TYPE_CHECKING

from sqlalchemy import types

from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils import core as utils

if TYPE_CHECKING:
from superset.connectors.sqla.models import TableColumn
Expand Down Expand Up @@ -53,12 +56,13 @@ def epoch_ms_to_dttm(cls) -> str:
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.TIMESTAMP:
sqla_type = cls.get_sqla_column_type(target_type)

if isinstance(sqla_type, types.TIMESTAMP):
return f"{dttm.timestamp() * 1000}"
return None

@classmethod
def alter_new_orm_column(cls, orm_col: "TableColumn") -> None:
def alter_new_orm_column(cls, orm_col: TableColumn) -> None:
if orm_col.type == "TIMESTAMP":
orm_col.python_date_format = "epoch_ms"
10 changes: 6 additions & 4 deletions superset/db_engine_specs/dremio.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
from datetime import datetime
from typing import Any, Dict, Optional

from sqlalchemy import types

from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils import core as utils


class DremioEngineSpec(BaseEngineSpec):
Expand Down Expand Up @@ -46,10 +47,11 @@ def epoch_to_dttm(cls) -> str:
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)

if isinstance(sqla_type, types.Date):
return f"TO_DATE('{dttm.date().isoformat()}', 'YYYY-MM-DD')"
if tt == utils.TemporalType.TIMESTAMP:
if isinstance(sqla_type, types.TIMESTAMP):
dttm_formatted = dttm.isoformat(sep=" ", timespec="milliseconds")
return f"""TO_TIMESTAMP('{dttm_formatted}', 'YYYY-MM-DD HH24:MI:SS.FFF')"""
return None
9 changes: 5 additions & 4 deletions superset/db_engine_specs/drill.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
from typing import Any, Dict, Optional
from urllib import parse

from sqlalchemy import types
from sqlalchemy.engine.url import URL

from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.exceptions import SupersetDBAPIProgrammingError
from superset.utils import core as utils


class DrillEngineSpec(BaseEngineSpec):
Expand Down Expand Up @@ -59,10 +59,11 @@ def epoch_ms_to_dttm(cls) -> str:
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)

if isinstance(sqla_type, types.Date):
return f"TO_DATE('{dttm.date().isoformat()}', 'yyyy-MM-dd')"
if tt == utils.TemporalType.TIMESTAMP:
if isinstance(sqla_type, types.TIMESTAMP):
datetime_formatted = dttm.isoformat(sep=" ", timespec="seconds")
return f"""TO_TIMESTAMP('{datetime_formatted}', 'yyyy-MM-dd HH:mm:ss')"""
return None
Expand Down
15 changes: 10 additions & 5 deletions superset/db_engine_specs/druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import json
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING

from sqlalchemy import types
from sqlalchemy.engine.reflection import Inspector

from superset import is_feature_enabled
Expand Down Expand Up @@ -70,12 +74,12 @@ class DruidEngineSpec(BaseEngineSpec):
}

@classmethod
def alter_new_orm_column(cls, orm_col: "TableColumn") -> None:
def alter_new_orm_column(cls, orm_col: TableColumn) -> None:
if orm_col.column_name == "__time":
orm_col.is_dttm = True

@staticmethod
def get_extra_params(database: "Database") -> Dict[str, Any]:
def get_extra_params(database: Database) -> Dict[str, Any]:
"""
For Druid, the path to a SSL certificate is placed in `connect_args`.

Expand All @@ -102,10 +106,11 @@ def get_extra_params(database: "Database") -> Dict[str, Any]:
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)

if isinstance(sqla_type, types.Date):
return f"CAST(TIME_PARSE('{dttm.date().isoformat()}') AS DATE)"
if tt in (utils.TemporalType.DATETIME, utils.TemporalType.TIMESTAMP):
if isinstance(sqla_type, (types.DateTime, types.TIMESTAMP)):
return f"""TIME_PARSE('{dttm.isoformat(timespec="seconds")}')"""
return None

Expand Down
7 changes: 4 additions & 3 deletions superset/db_engine_specs/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
from typing import Any, Dict, Optional, Pattern, Set, Tuple, TYPE_CHECKING

from flask_babel import gettext as __
from sqlalchemy import types
from sqlalchemy.engine.reflection import Inspector

from superset.db_engine_specs.base import BaseEngineSpec
from superset.errors import SupersetErrorType
from superset.utils import core as utils

if TYPE_CHECKING:
# prevent circular imports
Expand Down Expand Up @@ -67,8 +67,9 @@ def epoch_to_dttm(cls) -> str:
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt in (utils.TemporalType.TEXT, utils.TemporalType.DATETIME):
sqla_type = cls.get_sqla_column_type(target_type)

if isinstance(sqla_type, (types.String, types.DateTime)):
return f"""'{dttm.isoformat(sep=" ", timespec="microseconds")}'"""
return None

Expand Down
9 changes: 6 additions & 3 deletions superset/db_engine_specs/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
from datetime import datetime
from typing import Any, Dict, Optional

from sqlalchemy import types

from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils import core as utils


class DynamoDBEngineSpec(BaseEngineSpec):
Expand Down Expand Up @@ -56,7 +57,9 @@ def epoch_to_dttm(cls) -> str:
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt in (utils.TemporalType.TEXT, utils.TemporalType.DATETIME):
sqla_type = cls.get_sqla_column_type(target_type)

if isinstance(sqla_type, (types.String, types.DateTime)):
return f"""'{dttm.isoformat(sep=" ", timespec="seconds")}'"""

return None
Loading