From 7e54b88a519e70aee7ef3add39f0fed1d6fa7272 Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Fri, 18 Nov 2022 12:41:21 -0800 Subject: [PATCH] chore: Change get_table_names/get_view_names return type (#22085) --- superset/db_engine_specs/base.py | 16 +++++----- superset/db_engine_specs/databricks.py | 12 +++---- superset/db_engine_specs/duckdb.py | 6 ++-- superset/db_engine_specs/postgres.py | 10 +++--- superset/db_engine_specs/presto.py | 28 ++++++++++------ superset/db_engine_specs/sqlite.py | 6 ++-- superset/models/core.py | 32 ++++++++++++------- superset/views/core.py | 8 ++--- tests/integration_tests/datasets/api_tests.py | 2 +- .../db_engine_specs/base_engine_spec_tests.py | 4 +-- .../db_engine_specs/postgres_tests.py | 4 +-- .../db_engine_specs/presto_tests.py | 10 +++--- 12 files changed, 76 insertions(+), 62 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 96dc8eeecfebe..87951d396ef10 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1034,7 +1034,7 @@ def get_table_names( # pylint: disable=unused-argument database: "Database", inspector: Inspector, schema: Optional[str], - ) -> List[str]: + ) -> Set[str]: """ Get all the real table names within the specified schema. @@ -1048,13 +1048,13 @@ def get_table_names( # pylint: disable=unused-argument """ try: - tables = inspector.get_table_names(schema) + tables = set(inspector.get_table_names(schema)) except Exception as ex: raise cls.get_dbapi_mapped_exception(ex) from ex if schema and cls.try_remove_schema_from_table_name: - tables = [re.sub(f"^{schema}\\.", "", table) for table in tables] - return sorted(tables) + tables = {re.sub(f"^{schema}\\.", "", table) for table in tables} + return tables @classmethod def get_view_names( # pylint: disable=unused-argument @@ -1062,7 +1062,7 @@ def get_view_names( # pylint: disable=unused-argument database: "Database", inspector: Inspector, schema: Optional[str], - ) -> List[str]: + ) -> Set[str]: """ Get all the view names within the specified schema. @@ -1076,13 +1076,13 @@ def get_view_names( # pylint: disable=unused-argument """ try: - views = inspector.get_view_names(schema) + views = set(inspector.get_view_names(schema)) except Exception as ex: raise cls.get_dbapi_mapped_exception(ex) from ex if schema and cls.try_remove_schema_from_table_name: - views = [re.sub(f"^{schema}\\.", "", view) for view in views] - return sorted(views) + views = {re.sub(f"^{schema}\\.", "", view) for view in views} + return views @classmethod def get_table_comment( diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py index 90d90b9448fa7..8dce8a5940613 100644 --- a/superset/db_engine_specs/databricks.py +++ b/superset/db_engine_specs/databricks.py @@ -16,7 +16,7 @@ # under the License. from datetime import datetime -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import Any, Dict, Optional, Set, TYPE_CHECKING from sqlalchemy.engine.reflection import Inspector @@ -103,9 +103,7 @@ def get_table_names( database: "Database", inspector: Inspector, schema: Optional[str], - ) -> List[str]: - tables = set(super().get_table_names(database, inspector, schema)) - views = set(cls.get_view_names(database, inspector, schema)) - actual_tables = tables - views - - return list(actual_tables) + ) -> Set[str]: + return super().get_table_names( + database, inspector, schema + ) - cls.get_view_names(database, inspector, schema) diff --git a/superset/db_engine_specs/duckdb.py b/superset/db_engine_specs/duckdb.py index 577098a1ca572..c9eb287c9e44e 100644 --- a/superset/db_engine_specs/duckdb.py +++ b/superset/db_engine_specs/duckdb.py @@ -18,7 +18,7 @@ import re from datetime import datetime -from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING +from typing import Any, Dict, Optional, Pattern, Set, Tuple, TYPE_CHECKING from flask_babel import gettext as __ from sqlalchemy.engine.reflection import Inspector @@ -75,5 +75,5 @@ def convert_dttm( @classmethod def get_table_names( cls, database: Database, inspector: Inspector, schema: Optional[str] - ) -> List[str]: - return sorted(inspector.get_table_names(schema)) + ) -> Set[str]: + return set(inspector.get_table_names(schema)) diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index f9d450a3e9c9e..286b6e80a1ca7 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -18,7 +18,7 @@ import logging import re from datetime import datetime -from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING +from typing import Any, Dict, List, Optional, Pattern, Set, Tuple, TYPE_CHECKING from flask_babel import gettext as __ from sqlalchemy.dialects.postgresql import ARRAY, DOUBLE_PRECISION, ENUM, JSON @@ -228,11 +228,11 @@ def query_cost_formatter( @classmethod def get_table_names( cls, database: "Database", inspector: PGInspector, schema: Optional[str] - ) -> List[str]: + ) -> Set[str]: """Need to consider foreign tables for PostgreSQL""" - tables = inspector.get_table_names(schema) - tables.extend(inspector.get_foreign_table_names(schema)) - return sorted(tables) + return set(inspector.get_table_names(schema)) | set( + inspector.get_foreign_table_names(schema) + ) @classmethod def convert_dttm( diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index b513db0a61958..675503973485a 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -26,7 +26,18 @@ from datetime import datetime from distutils.version import StrictVersion from textwrap import dedent -from typing import Any, cast, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING, Union +from typing import ( + Any, + cast, + Dict, + List, + Optional, + Pattern, + Set, + Tuple, + TYPE_CHECKING, + Union, +) from urllib import parse import pandas as pd @@ -396,7 +407,7 @@ def get_table_names( database: Database, inspector: Inspector, schema: Optional[str], - ) -> List[str]: + ) -> Set[str]: """ Get all the real table names within the specified schema. @@ -414,12 +425,9 @@ def get_table_names( :returns: The physical table names """ - return sorted( - list( - set(super().get_table_names(database, inspector, schema)) - - set(cls.get_view_names(database, inspector, schema)) - ) - ) + return super().get_table_names( + database, inspector, schema + ) - cls.get_view_names(database, inspector, schema) @classmethod def get_view_names( @@ -427,7 +435,7 @@ def get_view_names( database: Database, inspector: Inspector, schema: Optional[str], - ) -> List[str]: + ) -> Set[str]: """ Get all the view names within the specified schema. @@ -468,7 +476,7 @@ def get_view_names( cursor.execute(sql, params) results = cursor.fetchall() - return sorted([row[0] for row in results]) + return {row[0] for row in results} @classmethod def _create_column_info( diff --git a/superset/db_engine_specs/sqlite.py b/superset/db_engine_specs/sqlite.py index 85442aa877363..8bd2d081ee9e3 100644 --- a/superset/db_engine_specs/sqlite.py +++ b/superset/db_engine_specs/sqlite.py @@ -16,7 +16,7 @@ # under the License. import re from datetime import datetime -from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING +from typing import Any, Dict, Optional, Pattern, Set, Tuple, TYPE_CHECKING from flask_babel import gettext as __ from sqlalchemy.engine.reflection import Inspector @@ -88,6 +88,6 @@ def convert_dttm( @classmethod def get_table_names( cls, database: "Database", inspector: Inspector, schema: Optional[str] - ) -> List[str]: + ) -> Set[str]: """Need to disregard the schema for Sqlite""" - return sorted(inspector.get_table_names()) + return set(inspector.get_table_names()) diff --git a/superset/models/core.py b/superset/models/core.py index 020f04f28af02..2ac824e88c326 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -543,7 +543,7 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument cache: bool = False, cache_timeout: Optional[int] = None, force: bool = False, - ) -> List[Tuple[str, str]]: + ) -> Set[Tuple[str, str]]: """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in @@ -553,13 +553,17 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument :param cache: whether cache is enabled for the function :param cache_timeout: timeout in seconds for the cache :param force: whether to force refresh the cache - :return: list of tables + :return: The table/schema pairs """ try: - tables = self.db_engine_spec.get_table_names( - database=self, inspector=self.inspector, schema=schema - ) - return [(table, schema) for table in tables] + return { + (table, schema) + for table in self.db_engine_spec.get_table_names( + database=self, + inspector=self.inspector, + schema=schema, + ) + } except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) @@ -573,7 +577,7 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument cache: bool = False, cache_timeout: Optional[int] = None, force: bool = False, - ) -> List[Tuple[str, str]]: + ) -> Set[Tuple[str, str]]: """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in @@ -583,13 +587,17 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument :param cache: whether cache is enabled for the function :param cache_timeout: timeout in seconds for the cache :param force: whether to force refresh the cache - :return: list of views + :return: set of views """ try: - views = self.db_engine_spec.get_view_names( - database=self, inspector=self.inspector, schema=schema - ) - return [(view, schema) for view in views] + return { + (view, schema) + for view in self.db_engine_spec.get_view_names( + database=self, + inspector=self.inspector, + schema=schema, + ) + } except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) diff --git a/superset/views/core.py b/superset/views/core.py index edde19871a32a..534f8f667d707 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -1173,7 +1173,7 @@ def tables( # pylint: disable=no-self-use tables = security_manager.get_datasources_accessible_by_user( database=database, schema=schema_parsed, - datasource_names=[ + datasource_names=sorted( utils.DatasourceName(*datasource_name) for datasource_name in database.get_all_table_names_in_schema( schema=schema_parsed, @@ -1181,13 +1181,13 @@ def tables( # pylint: disable=no-self-use cache=database.table_cache_enabled, cache_timeout=database.table_cache_timeout, ) - ], + ), ) views = security_manager.get_datasources_accessible_by_user( database=database, schema=schema_parsed, - datasource_names=[ + datasource_names=sorted( utils.DatasourceName(*datasource_name) for datasource_name in database.get_all_view_names_in_schema( schema=schema_parsed, @@ -1195,7 +1195,7 @@ def tables( # pylint: disable=no-self-use cache=database.table_cache_enabled, cache_timeout=database.table_cache_timeout, ) - ], + ), ) except SupersetException as ex: return json_error_response(ex.message, ex.status) diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 33243a801cdb0..af3a956834aac 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -767,7 +767,7 @@ def test_create_dataset_validate_view_exists( with patch.object( dialect, "get_view_names", wraps=dialect.get_view_names ) as patch_get_view_names: - patch_get_view_names.return_value = ["test_case_view"] + patch_get_view_names.return_value = {"test_case_view"} self.login(username="admin") table_data = { diff --git a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py index c31a501487dda..0d945f8edea43 100644 --- a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py +++ b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py @@ -229,11 +229,11 @@ def test_get_table_names(self): """ Make sure base engine spec removes schema name from table name ie. when try_remove_schema_from_table_name == True. """ - base_result_expected = ["table", "table_2"] + base_result_expected = {"table", "table_2"} base_result = BaseEngineSpec.get_table_names( database=mock.ANY, schema="schema", inspector=inspector ) - self.assertListEqual(base_result_expected, base_result) + assert base_result_expected == base_result @pytest.mark.usefixtures("load_energy_table_with_slice") def test_column_datatype_to_string(self): diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py index 5021075fe7728..a9dbfa515f602 100644 --- a/tests/integration_tests/db_engine_specs/postgres_tests.py +++ b/tests/integration_tests/db_engine_specs/postgres_tests.py @@ -45,11 +45,11 @@ def test_get_table_names(self): inspector.get_table_names = mock.Mock(return_value=["schema.table", "table_2"]) inspector.get_foreign_table_names = mock.Mock(return_value=["table_3"]) - pg_result_expected = ["schema.table", "table_2", "table_3"] + pg_result_expected = {"schema.table", "table_2", "table_3"} pg_result = PostgresEngineSpec.get_table_names( database=mock.ANY, schema="schema", inspector=inspector ) - self.assertListEqual(pg_result_expected, pg_result) + assert pg_result_expected == pg_result def test_time_exp_literal_no_grain(self): """ diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py index a38617e8a9a85..4a76d59a46faf 100644 --- a/tests/integration_tests/db_engine_specs/presto_tests.py +++ b/tests/integration_tests/db_engine_specs/presto_tests.py @@ -56,7 +56,7 @@ def test_get_view_names_with_schema(self): ).strip(), {"schema": schema}, ) - assert result == ["a", "d"] + assert result == {"a", "d"} def test_get_view_names_without_schema(self): database = mock.MagicMock() @@ -77,7 +77,7 @@ def test_get_view_names_without_schema(self): ).strip(), {}, ) - assert result == ["a", "d"] + assert result == {"a", "d"} def verify_presto_column(self, column, expected_results): inspector = mock.Mock() @@ -670,10 +670,10 @@ def test_get_table_names( mock_get_view_names, mock_get_table_names, ): - mock_get_view_names.return_value = ["view1", "view2"] - mock_get_table_names.return_value = ["table1", "table2", "view1", "view2"] + mock_get_view_names.return_value = {"view1", "view2"} + mock_get_table_names.return_value = {"table1", "table2", "view1", "view2"} tables = PrestoEngineSpec.get_table_names(mock.Mock(), mock.Mock(), None) - assert tables == ["table1", "table2"] + assert tables == {"table1", "table2"} def test_get_full_name(self): names = [