From a1687c800dc081679810e5b16d70fde488cffb1d Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Thu, 11 Jul 2024 17:01:37 -0400 Subject: [PATCH] fix: Trino get_columns --- superset/db_engine_specs/base.py | 28 +- superset/db_engine_specs/couchbasedb.py | 1 - superset/db_engine_specs/presto.py | 408 +++++++++--------- superset/db_engine_specs/trino.py | 23 +- tests/unit_tests/db_engine_specs/test_base.py | 23 + .../unit_tests/db_engine_specs/test_trino.py | 63 ++- 6 files changed, 323 insertions(+), 223 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 159c510fe9656..1329597f02b72 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1618,7 +1618,7 @@ def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[Any]: ] @classmethod - def select_star( # pylint: disable=too-many-arguments,too-many-locals + def select_star( # pylint: disable=too-many-arguments cls, database: Database, table: Table, @@ -1653,14 +1653,7 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals if show_cols: fields = cls._get_fields(cols) - quote = engine.dialect.identifier_preparer.quote - quote_schema = engine.dialect.identifier_preparer.quote_schema - full_table_name = ( - quote_schema(table.schema) + "." + quote(table.table) - if table.schema - else quote(table.table) - ) - + full_table_name = cls.quote_table(table, engine.dialect) qry = select(fields).select_from(text(full_table_name)) if limit and cls.allow_limit_clause: @@ -2224,6 +2217,23 @@ def denormalize_name(cls, dialect: Dialect, name: str) -> str: return name + @classmethod + def quote_table(cls, table: Table, dialect: Dialect) -> str: + """ + Fully quote a table name, including the schema and catalog. + """ + quoters = { + "catalog": dialect.identifier_preparer.quote_schema, + "schema": dialect.identifier_preparer.quote_schema, + "table": dialect.identifier_preparer.quote, + } + + return ".".join( + function(getattr(table, key)) + for key, function in quoters.items() + if getattr(table, key) + ) + # schema for adding a database by providing parameters instead of the # full SQLAlchemy URI diff --git a/superset/db_engine_specs/couchbasedb.py b/superset/db_engine_specs/couchbasedb.py index b9cebdba3247d..71dc7276791a1 100644 --- a/superset/db_engine_specs/couchbasedb.py +++ b/superset/db_engine_specs/couchbasedb.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# pylint: disable=too-many-lines from __future__ import annotations diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index fbd0eff484474..5a375896c1d46 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -672,6 +672,209 @@ def latest_sub_partition( return "" return df.to_dict()[field_to_return][0] + @classmethod + def _show_columns( + cls, + inspector: Inspector, + table: Table, + ) -> list[ResultRow]: + """ + Show presto column names + :param inspector: object that performs database schema inspection + :param table: table instance + :return: list of column objects + """ + full_table_name = cls.quote_table(table, inspector.engine.dialect) + return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table_name}").fetchall() + + @classmethod + def _create_column_info( + cls, name: str, data_type: types.TypeEngine + ) -> ResultSetColumnType: + """ + Create column info object + :param name: column name + :param data_type: column data type + :return: column info object + """ + return { + "column_name": name, + "name": name, + "type": f"{data_type}", + "is_dttm": None, + "type_generic": None, + } + + @classmethod + def get_columns( + cls, + inspector: Inspector, + table: Table, + options: dict[str, Any] | None = None, + ) -> list[ResultSetColumnType]: + """ + Get columns from a Presto data source. This includes handling row and + array data types + :param inspector: object that performs database schema inspection + :param table: table instance + :param options: Extra configuration options, not used by this backend + :return: a list of results that contain column info + (i.e. column name and data type) + """ + columns = cls._show_columns(inspector, table) + result: list[ResultSetColumnType] = [] + for column in columns: + # parse column if it is a row or array + if is_feature_enabled("PRESTO_EXPAND_DATA") and ( + "array" in column.Type or "row" in column.Type + ): + structural_column_index = len(result) + cls._parse_structural_column(column.Column, column.Type, result) + result[structural_column_index]["nullable"] = getattr( + column, "Null", True + ) + result[structural_column_index]["default"] = None + continue + + # otherwise column is a basic data type + column_spec = cls.get_column_spec(column.Type) + column_type = column_spec.sqla_type if column_spec else None + if column_type is None: + column_type = types.String() + logger.info( + "Did not recognize type %s of column %s", + str(column.Type), + str(column.Column), + ) + column_info = cls._create_column_info(column.Column, column_type) + column_info["nullable"] = getattr(column, "Null", True) + column_info["default"] = None + column_info["column_name"] = column.Column + result.append(column_info) + + return result + + @classmethod + def _parse_structural_column( # pylint: disable=too-many-locals + cls, + parent_column_name: str, + parent_data_type: str, + result: list[ResultSetColumnType], + ) -> None: + """ + Parse a row or array column + :param result: list tracking the results + """ + formatted_parent_column_name = parent_column_name + # Quote the column name if there is a space + if " " in parent_column_name: + formatted_parent_column_name = f'"{parent_column_name}"' + full_data_type = f"{formatted_parent_column_name} {parent_data_type}" + original_result_len = len(result) + # split on open parenthesis ( to get the structural + # data type and its component types + data_types = cls._split_data_type(full_data_type, r"\(") + stack: list[tuple[str, str]] = [] + for 
data_type in data_types:
+            # split on closed parenthesis ) to track which component
+            # types belong to what structural data type
+            inner_types = cls._split_data_type(data_type, r"\)")
+            for inner_type in inner_types:
+                # We have finished parsing multiple structural data types
+                if not inner_type and stack:
+                    stack.pop()
+                elif cls._has_nested_data_types(inner_type):
+                    # split on comma , to get individual data types
+                    single_fields = cls._split_data_type(inner_type, ",")
+                    for single_field in single_fields:
+                        single_field = single_field.strip()
+                        # If component type starts with a comma, the first single field
+                        # will be an empty string. Disregard this empty string.
+                        if not single_field:
+                            continue
+                        # split on whitespace to get field name and data type
+                        field_info = cls._split_data_type(single_field, r"\s")
+                        # check if there is a structural data type within
+                        # overall structural data type
+                        column_spec = cls.get_column_spec(field_info[1])
+                        column_type = column_spec.sqla_type if column_spec else None
+                        if column_type is None:
+                            column_type = types.String()
+                            logger.info(
+                                "Did not recognize type %s of column %s",
+                                field_info[1],
+                                field_info[0],
+                            )
+                        if field_info[1] == "array" or field_info[1] == "row":
+                            stack.append((field_info[0], field_info[1]))
+                            full_parent_path = cls._get_full_name(stack)
+                            result.append(
+                                cls._create_column_info(full_parent_path, column_type)
+                            )
+                        else:  # otherwise this field is a basic data type
+                            full_parent_path = cls._get_full_name(stack)
+                            column_name = f"{full_parent_path}.{field_info[0]}"
+                            result.append(
+                                cls._create_column_info(column_name, column_type)
+                            )
+                    # If the component type ends with a structural data type, do not pop
+                    # the stack. We have run across a structural data type within the
+                    # overall structural data type. Otherwise, we have completely parsed
+                    # through the entire structural data type and can move on.
+                    if not (inner_type.endswith("array") or inner_type.endswith("row")):
+                        stack.pop()
+                # We have an array of row objects (i.e. array(row(...)))
+                elif inner_type in ("array", "row"):
+                    # Push a dummy object to represent the structural data type
+                    stack.append(("", inner_type))
+                # We have an array of a basic data type (i.e. array(varchar)).
+                elif stack:
+                    # Because it is an array of a basic data type, we have finished
+                    # parsing the structural data type and can move on.
+                    stack.pop()
+        # Unquote the column name if necessary
+        if formatted_parent_column_name != parent_column_name:
+            for index in range(original_result_len, len(result)):
+                result[index]["column_name"] = result[index]["column_name"].replace(
+                    formatted_parent_column_name, parent_column_name
+                )
+
+    @classmethod
+    def _split_data_type(cls, data_type: str, delimiter: str) -> list[str]:
+        """
+        Split data type based on given delimiter. Do not split the string if the
+        delimiter is enclosed in quotes
+        :param data_type: data type
+        :param delimiter: string separator (i.e. open parenthesis, closed parenthesis,
+        comma, whitespace)
+        :return: list of strings after breaking it by the delimiter
+        """
+        return re.split(rf"{delimiter}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)", data_type)
+
+    @classmethod
+    def _has_nested_data_types(cls, component_type: str) -> bool:
+        """
+        Check if a string contains a data type. 
We determine if there is a data type by + whitespace or multiple data types by commas + :param component_type: data type + :return: boolean + """ + comma_regex = r",(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)" + white_space_regex = r"\s(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)" + return ( + re.search(comma_regex, component_type) is not None + or re.search(white_space_regex, component_type) is not None + ) + + @classmethod + def _get_full_name(cls, names: list[tuple[str, str]]) -> str: + """ + Get the full column name + :param names: list of all individual column names + :return: full column name + """ + return ".".join(column[0] for column in names if column[0]) + class PrestoEngineSpec(PrestoBaseEngineSpec): engine = "presto" @@ -840,211 +1043,6 @@ def get_view_names( results = cursor.fetchall() return {row[0] for row in results} - @classmethod - def _create_column_info( - cls, name: str, data_type: types.TypeEngine - ) -> ResultSetColumnType: - """ - Create column info object - :param name: column name - :param data_type: column data type - :return: column info object - """ - return { - "column_name": name, - "name": name, - "type": f"{data_type}", - "is_dttm": None, - "type_generic": None, - } - - @classmethod - def _get_full_name(cls, names: list[tuple[str, str]]) -> str: - """ - Get the full column name - :param names: list of all individual column names - :return: full column name - """ - return ".".join(column[0] for column in names if column[0]) - - @classmethod - def _has_nested_data_types(cls, component_type: str) -> bool: - """ - Check if string contains a data type. We determine if there is a data type by - whitespace or multiple data types by commas - :param component_type: data type - :return: boolean - """ - comma_regex = r",(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)" - white_space_regex = r"\s(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)" - return ( - re.search(comma_regex, component_type) is not None - or re.search(white_space_regex, component_type) is not None - ) - - @classmethod - def _split_data_type(cls, data_type: str, delimiter: str) -> list[str]: - """ - Split data type based on given delimiter. Do not split the string if the - delimiter is enclosed in quotes - :param data_type: data type - :param delimiter: string separator (i.e. 
open parenthesis, closed parenthesis, - comma, whitespace) - :return: list of strings after breaking it by the delimiter - """ - return re.split(rf"{delimiter}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)", data_type) - - @classmethod - def _parse_structural_column( # pylint: disable=too-many-locals - cls, - parent_column_name: str, - parent_data_type: str, - result: list[ResultSetColumnType], - ) -> None: - """ - Parse a row or array column - :param result: list tracking the results - """ - formatted_parent_column_name = parent_column_name - # Quote the column name if there is a space - if " " in parent_column_name: - formatted_parent_column_name = f'"{parent_column_name}"' - full_data_type = f"{formatted_parent_column_name} {parent_data_type}" - original_result_len = len(result) - # split on open parenthesis ( to get the structural - # data type and its component types - data_types = cls._split_data_type(full_data_type, r"\(") - stack: list[tuple[str, str]] = [] - for data_type in data_types: - # split on closed parenthesis ) to track which component - # types belong to what structural data type - inner_types = cls._split_data_type(data_type, r"\)") - for inner_type in inner_types: - # We have finished parsing multiple structural data types - if not inner_type and stack: - stack.pop() - elif cls._has_nested_data_types(inner_type): - # split on comma , to get individual data types - single_fields = cls._split_data_type(inner_type, ",") - for single_field in single_fields: - single_field = single_field.strip() - # If component type starts with a comma, the first single field - # will be an empty string. Disregard this empty string. - if not single_field: - continue - # split on whitespace to get field name and data type - field_info = cls._split_data_type(single_field, r"\s") - # check if there is a structural data type within - # overall structural data type - column_spec = cls.get_column_spec(field_info[1]) - column_type = column_spec.sqla_type if column_spec else None - if column_type is None: - column_type = types.String() - logger.info( - "Did not recognize type %s of column %s", - field_info[1], - field_info[0], - ) - if field_info[1] == "array" or field_info[1] == "row": - stack.append((field_info[0], field_info[1])) - full_parent_path = cls._get_full_name(stack) - result.append( - cls._create_column_info(full_parent_path, column_type) - ) - else: # otherwise this field is a basic data type - full_parent_path = cls._get_full_name(stack) - column_name = f"{full_parent_path}.{field_info[0]}" - result.append( - cls._create_column_info(column_name, column_type) - ) - # If the component type ends with a structural data type, do not pop - # the stack. We have run across a structural data type within the - # overall structural data type. Otherwise, we have completely parsed - # through the entire structural data type and can move on. - if not (inner_type.endswith("array") or inner_type.endswith("row")): - stack.pop() - # We have an array of row objects (i.e. array(row(...))) - elif inner_type in ("array", "row"): - # Push a dummy object to represent the structural data type - stack.append(("", inner_type)) - # We have an array of a basic data types(i.e. array(varchar)). - elif stack: - # Because it is an array of a basic data type. We have finished - # parsing the structural data type and can move on. 
- stack.pop() - # Unquote the column name if necessary - if formatted_parent_column_name != parent_column_name: - for index in range(original_result_len, len(result)): - result[index]["column_name"] = result[index]["column_name"].replace( - formatted_parent_column_name, parent_column_name - ) - - @classmethod - def _show_columns( - cls, - inspector: Inspector, - table: Table, - ) -> list[ResultRow]: - """ - Show presto column names - :param inspector: object that performs database schema inspection - :param table: table instance - :return: list of column objects - """ - quote = inspector.engine.dialect.identifier_preparer.quote_identifier - full_table = quote(table.table) - if table.schema: - full_table = f"{quote(table.schema)}.{full_table}" - return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table}").fetchall() - - @classmethod - def get_columns( - cls, - inspector: Inspector, - table: Table, - options: dict[str, Any] | None = None, - ) -> list[ResultSetColumnType]: - """ - Get columns from a Presto data source. This includes handling row and - array data types - :param inspector: object that performs database schema inspection - :param table: table instance - :param options: Extra configuration options, not used by this backend - :return: a list of results that contain column info - (i.e. column name and data type) - """ - columns = cls._show_columns(inspector, table) - result: list[ResultSetColumnType] = [] - for column in columns: - # parse column if it is a row or array - if is_feature_enabled("PRESTO_EXPAND_DATA") and ( - "array" in column.Type or "row" in column.Type - ): - structural_column_index = len(result) - cls._parse_structural_column(column.Column, column.Type, result) - result[structural_column_index]["nullable"] = getattr( - column, "Null", True - ) - result[structural_column_index]["default"] = None - continue - - # otherwise column is a basic data type - column_spec = cls.get_column_spec(column.Type) - column_type = column_spec.sqla_type if column_spec else None - if column_type is None: - column_type = types.String() - logger.info( - "Did not recognize type %s of column %s", - str(column.Type), - str(column.Column), - ) - column_info = cls._create_column_info(column.Column, column_type) - column_info["nullable"] = getattr(column, "Null", True) - column_info["default"] = None - column_info["column_name"] = column.Column - result.append(column_info) - return result - @classmethod def _is_column_name_quoted(cls, column_name: str) -> bool: """ diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 143276bdc3dca..1eb4b307870d0 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -36,7 +36,7 @@ from superset import db from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT from superset.databases.utils import make_url_safe -from superset.db_engine_specs.base import BaseEngineSpec +from superset.db_engine_specs.base import BaseEngineSpec, convert_inspector_columns from superset.db_engine_specs.exceptions import ( SupersetDBAPIConnectionError, SupersetDBAPIDatabaseError, @@ -241,7 +241,11 @@ def _execute( execute_thread = threading.Thread( target=_execute, - args=(execute_result, execute_event, current_app._get_current_object()), # pylint: disable=protected-access + args=( + execute_result, + execute_event, + current_app._get_current_object(), # pylint: disable=protected-access + ), ) execute_thread.start() @@ -433,7 +437,17 @@ def get_columns( "schema_options", expand the 
schema definition out to show all
         subfields of nested ROWs as their appropriate dotted paths.
         """
-        base_cols = super().get_columns(inspector, table, options)
+        # The Trino dialect raises `NoSuchTableError` on the inspection methods when the
+        # table is empty. We can work around this by running a `SHOW COLUMNS FROM` query
+        # when that happens, using the method from the Presto base engine spec.
+        try:
+            # `SELECT * FROM information_schema.columns WHERE ...`
+            sqla_columns = inspector.get_columns(table.table, table.schema)
+            base_cols = convert_inspector_columns(sqla_columns)
+        except NoSuchTableError:
+            # `SHOW COLUMNS FROM ...`
+            base_cols = super().get_columns(inspector, table, options)
+
         if not (options or {}).get("expand_rows"):
             return base_cols
 
@@ -483,9 +497,6 @@ def df_to_sql(
         :param to_sql_kwargs: The `pandas.DataFrame.to_sql` keyword arguments
         :see: superset.db_engine_specs.HiveEngineSpec.df_to_sql
         """
-
-        # pylint: disable=import-outside-toplevel
-
         if to_sql_kwargs["if_exists"] == "append":
             raise SupersetException("Append operation not currently supported")
 
diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py
index a3920c8916ae5..9ec1ebaf00efe 100644
--- a/tests/unit_tests/db_engine_specs/test_base.py
+++ b/tests/unit_tests/db_engine_specs/test_base.py
@@ -311,3 +311,26 @@ def test_get_default_catalog(mocker: MockerFixture) -> None:
     database = mocker.MagicMock()
 
     assert BaseEngineSpec.get_default_catalog(database) is None
+
+
+def test_quote_table() -> None:
+    """
+    Test the `quote_table` function.
+    """
+    from superset.db_engine_specs.base import BaseEngineSpec
+
+    dialect = sqlite.dialect()
+
+    assert BaseEngineSpec.quote_table(Table("table"), dialect) == '"table"'
+    assert (
+        BaseEngineSpec.quote_table(Table("table", "schema"), dialect)
+        == 'schema."table"'
+    )
+    assert (
+        BaseEngineSpec.quote_table(Table("table", "schema", "catalog"), dialect)
+        == 'catalog.schema."table"'
+    )
+    assert (
+        BaseEngineSpec.quote_table(Table("ta ble", "sche.ma", 'cata"log'), dialect)
+        == '"cata""log"."sche.ma"."ta ble"'
+    )
diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py
index 3a2ac91ad623f..a0923e8111860 100644
--- a/tests/unit_tests/db_engine_specs/test_trino.py
+++ b/tests/unit_tests/db_engine_specs/test_trino.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=unused-argument, import-outside-toplevel, protected-access
 import copy
+from collections import namedtuple
 from datetime import datetime
 from typing import Any, Optional
 from unittest.mock import MagicMock, Mock, patch
@@ -25,7 +26,9 @@
 from pytest_mock import MockerFixture
 from requests.exceptions import ConnectionError as RequestsConnectionError
 from sqlalchemy import sql, text, types
+from sqlalchemy.dialects import sqlite
 from sqlalchemy.engine.url import make_url
+from sqlalchemy.exc import NoSuchTableError
 from trino.exceptions import TrinoExternalError, TrinoInternalError, TrinoUserError
 from trino.sqlalchemy import datatype
 from trino.sqlalchemy.dialect import TrinoDialect
@@ -464,6 +467,64 @@ def test_get_columns(mocker: MockerFixture):
     _assert_columns_equal(actual, expected)
 
 
+def test_get_columns_error(mocker: MockerFixture):
+    """
+    Test that we fall back to a `SHOW COLUMNS FROM ...` query.
+ """ + from superset.db_engine_specs.trino import TrinoEngineSpec + + field1_type = datatype.parse_sqltype("row(a varchar, b date)") + field2_type = datatype.parse_sqltype("row(r1 row(a varchar, b varchar))") + field3_type = datatype.parse_sqltype("int") + + mock_inspector = mocker.MagicMock() + mock_inspector.engine.dialect = sqlite.dialect() + mock_inspector.get_columns.side_effect = NoSuchTableError( + "The specified table does not exist." + ) + Row = namedtuple("Row", ["Column", "Type"]) + mock_inspector.bind.execute().fetchall.return_value = [ + Row("field1", "row(a varchar, b date)"), + Row("field2", "row(r1 row(a varchar, b varchar))"), + Row("field3", "int"), + ] + + actual = TrinoEngineSpec.get_columns(mock_inspector, Table("table", "schema")) + expected = [ + ResultSetColumnType( + name="field1", + column_name="field1", + type=field1_type, + is_dttm=None, + type_generic=None, + default=None, + nullable=True, + ), + ResultSetColumnType( + name="field2", + column_name="field2", + type=field2_type, + is_dttm=None, + type_generic=None, + default=None, + nullable=True, + ), + ResultSetColumnType( + name="field3", + column_name="field3", + type=field3_type, + is_dttm=None, + type_generic=None, + default=None, + nullable=True, + ), + ] + + _assert_columns_equal(actual, expected) + + mock_inspector.bind.execute.assert_called_with('SHOW COLUMNS FROM schema."table"') + + def test_get_columns_expand_rows(mocker: MockerFixture): """Test that ROW columns are correctly expanded with expand_rows""" from superset.db_engine_specs.trino import TrinoEngineSpec @@ -536,8 +597,6 @@ def test_get_columns_expand_rows(mocker: MockerFixture): def test_get_indexes_no_table(): - from sqlalchemy.exc import NoSuchTableError - from superset.db_engine_specs.trino import TrinoEngineSpec db_mock = Mock()
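The quoting behavior the new `quote_table` helper in base.py relies on can be exercised with SQLAlchemy alone. A minimal standalone sketch (not part of the patch), using the SQLite dialect as the new `test_quote_table` does; note that `quote_schema` is applied to both catalog and schema, so each is treated as a single identifier even when it contains a dot:

from sqlalchemy.dialects import sqlite

preparer = sqlite.dialect().identifier_preparer

# `quote_schema` quotes the whole name as one identifier (an embedded dot
# does not split it); `quote` doubles any embedded double quotes.
parts = [
    preparer.quote_schema('cata"log'),  # -> "cata""log"
    preparer.quote_schema("sche.ma"),   # -> "sche.ma"
    preparer.quote("ta ble"),           # -> "ta ble" (the space forces quoting)
]
print(".".join(parts))  # "cata""log"."sche.ma"."ta ble"

This reproduces the last assertion in `test_quote_table` without importing Superset.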
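The lookahead regex shared by `_split_data_type` and `_has_nested_data_types` in presto.py only permits a match when the remainder of the string contains an even number of double quotes, i.e. when the delimiter is not inside a quoted identifier. A small standalone sketch (not part of the patch) with a hypothetical type string:

import re

def split_data_type(data_type: str, delimiter: str) -> list[str]:
    # Split on `delimiter` unless it sits inside a double-quoted name: the
    # lookahead requires the quotes ahead of the match to pair up evenly.
    return re.split(rf"{delimiter}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)", data_type)

# The comma inside the quoted field name "a,b" is not a split point:
print(split_data_type('row("a,b" varchar, c integer)', ","))
# ['row("a,b" varchar', ' c integer)']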
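The control flow of the Trino fix itself can be summarized outside Superset. In this sketch the inspector and the fallback result are hypothetical stubs, not Superset code; only the try/except shape mirrors the patched `TrinoEngineSpec.get_columns`:

from sqlalchemy.exc import NoSuchTableError

class EmptyTableInspector:
    """Stub inspector mimicking the Trino dialect against an empty table."""

    def get_columns(self, table_name, schema=None):
        # The Trino dialect raises NoSuchTableError when the table has no
        # column metadata in information_schema.
        raise NoSuchTableError(table_name)

def get_columns_with_fallback(inspector, table_name, schema=None):
    try:
        # Normal path: dialect inspection via information_schema.columns.
        return inspector.get_columns(table_name, schema)
    except NoSuchTableError:
        # Fallback path: the patch issues `SHOW COLUMNS FROM ...` through
        # the Presto base spec; stubbed here with a fixed result.
        return [{"column_name": "field3", "type": "int"}]

print(get_columns_with_fallback(EmptyTableInspector(), "t", "s"))
# [{'column_name': 'field3', 'type': 'int'}]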