From e7c73d820f83b9792dc0a093869e86ae28f83ba9 Mon Sep 17 00:00:00 2001 From: John Bodley Date: Wed, 9 Nov 2022 13:32:19 -0800 Subject: [PATCH] chore: Change get_table_names/get_view_names return type --- superset/db_engine_specs/base.py | 16 +++++----- superset/db_engine_specs/databricks.py | 12 +++---- superset/db_engine_specs/duckdb.py | 6 ++-- superset/db_engine_specs/postgres.py | 10 +++--- superset/db_engine_specs/presto.py | 28 ++++++++++------ superset/db_engine_specs/sqlite.py | 6 ++-- superset/models/core.py | 32 ++++++++++++------- superset/views/core.py | 8 ++--- tests/integration_tests/datasets/api_tests.py | 2 +- .../db_engine_specs/base_engine_spec_tests.py | 4 +-- .../db_engine_specs/postgres_tests.py | 4 +-- .../db_engine_specs/presto_tests.py | 10 +++--- 12 files changed, 76 insertions(+), 62 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 2a1363e0b6957..2d1b9d5dcb0d6 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1016,7 +1016,7 @@ def get_table_names( # pylint: disable=unused-argument database: "Database", inspector: Inspector, schema: Optional[str], - ) -> List[str]: + ) -> Set[str]: """ Get all the real table names within the specified schema. @@ -1030,13 +1030,13 @@ def get_table_names( # pylint: disable=unused-argument """ try: - tables = inspector.get_table_names(schema) + tables = set(inspector.get_table_names(schema)) except Exception as ex: raise cls.get_dbapi_mapped_exception(ex) from ex if schema and cls.try_remove_schema_from_table_name: - tables = [re.sub(f"^{schema}\\.", "", table) for table in tables] - return sorted(tables) + tables = {re.sub(f"^{schema}\\.", "", table) for table in tables} + return tables @classmethod def get_view_names( # pylint: disable=unused-argument @@ -1044,7 +1044,7 @@ def get_view_names( # pylint: disable=unused-argument database: "Database", inspector: Inspector, schema: Optional[str], - ) -> List[str]: + ) -> Set[str]: """ Get all the view names within the specified schema. @@ -1058,13 +1058,13 @@ def get_view_names( # pylint: disable=unused-argument """ try: - views = inspector.get_view_names(schema) + views = set(inspector.get_view_names(schema)) except Exception as ex: raise cls.get_dbapi_mapped_exception(ex) from ex if schema and cls.try_remove_schema_from_table_name: - views = [re.sub(f"^{schema}\\.", "", view) for view in views] - return sorted(views) + views = {re.sub(f"^{schema}\\.", "", view) for view in views} + return views @classmethod def get_table_comment( diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py index 90d90b9448fa7..8dce8a5940613 100644 --- a/superset/db_engine_specs/databricks.py +++ b/superset/db_engine_specs/databricks.py @@ -16,7 +16,7 @@ # under the License. from datetime import datetime -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import Any, Dict, Optional, Set, TYPE_CHECKING from sqlalchemy.engine.reflection import Inspector @@ -103,9 +103,7 @@ def get_table_names( database: "Database", inspector: Inspector, schema: Optional[str], - ) -> List[str]: - tables = set(super().get_table_names(database, inspector, schema)) - views = set(cls.get_view_names(database, inspector, schema)) - actual_tables = tables - views - - return list(actual_tables) + ) -> Set[str]: + return super().get_table_names( + database, inspector, schema + ) - cls.get_view_names(database, inspector, schema) diff --git a/superset/db_engine_specs/duckdb.py b/superset/db_engine_specs/duckdb.py index 577098a1ca572..c9eb287c9e44e 100644 --- a/superset/db_engine_specs/duckdb.py +++ b/superset/db_engine_specs/duckdb.py @@ -18,7 +18,7 @@ import re from datetime import datetime -from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING +from typing import Any, Dict, Optional, Pattern, Set, Tuple, TYPE_CHECKING from flask_babel import gettext as __ from sqlalchemy.engine.reflection import Inspector @@ -75,5 +75,5 @@ def convert_dttm( @classmethod def get_table_names( cls, database: Database, inspector: Inspector, schema: Optional[str] - ) -> List[str]: - return sorted(inspector.get_table_names(schema)) + ) -> Set[str]: + return set(inspector.get_table_names(schema)) diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index f9d450a3e9c9e..286b6e80a1ca7 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -18,7 +18,7 @@ import logging import re from datetime import datetime -from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING +from typing import Any, Dict, List, Optional, Pattern, Set, Tuple, TYPE_CHECKING from flask_babel import gettext as __ from sqlalchemy.dialects.postgresql import ARRAY, DOUBLE_PRECISION, ENUM, JSON @@ -228,11 +228,11 @@ def query_cost_formatter( @classmethod def get_table_names( cls, database: "Database", inspector: PGInspector, schema: Optional[str] - ) -> List[str]: + ) -> Set[str]: """Need to consider foreign tables for PostgreSQL""" - tables = inspector.get_table_names(schema) - tables.extend(inspector.get_foreign_table_names(schema)) - return sorted(tables) + return set(inspector.get_table_names(schema)) | set( + inspector.get_foreign_table_names(schema) + ) @classmethod def convert_dttm( diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index e959eb219506a..e79e95eba14ba 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -26,7 +26,18 @@ from datetime import datetime from distutils.version import StrictVersion from textwrap import dedent -from typing import Any, cast, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING, Union +from typing import ( + Any, + cast, + Dict, + List, + Optional, + Pattern, + Set, + Tuple, + TYPE_CHECKING, + Union, +) from urllib import parse import pandas as pd @@ -396,7 +407,7 @@ def get_table_names( database: Database, inspector: Inspector, schema: Optional[str], - ) -> List[str]: + ) -> Set[str]: """ Get all the real table names within the specified schema. @@ -414,12 +425,9 @@ def get_table_names( :returns: The physical table names """ - return sorted( - list( - set(super().get_table_names(database, inspector, schema)) - - set(cls.get_view_names(database, inspector, schema)) - ) - ) + return super().get_table_names( + database, inspector, schema + ) - cls.get_view_names(database, inspector, schema) @classmethod def get_view_names( @@ -427,7 +435,7 @@ def get_view_names( database: Database, inspector: Inspector, schema: Optional[str], - ) -> List[str]: + ) -> Set[str]: """ Get all the view names within the specified schema. @@ -469,7 +477,7 @@ def get_view_names( cursor.execute(sql, params) results = cursor.fetchall() - return sorted([row[0] for row in results]) + return {row[0] for row in results} @classmethod def _create_column_info( diff --git a/superset/db_engine_specs/sqlite.py b/superset/db_engine_specs/sqlite.py index 85442aa877363..8bd2d081ee9e3 100644 --- a/superset/db_engine_specs/sqlite.py +++ b/superset/db_engine_specs/sqlite.py @@ -16,7 +16,7 @@ # under the License. import re from datetime import datetime -from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING +from typing import Any, Dict, Optional, Pattern, Set, Tuple, TYPE_CHECKING from flask_babel import gettext as __ from sqlalchemy.engine.reflection import Inspector @@ -88,6 +88,6 @@ def convert_dttm( @classmethod def get_table_names( cls, database: "Database", inspector: Inspector, schema: Optional[str] - ) -> List[str]: + ) -> Set[str]: """Need to disregard the schema for Sqlite""" - return sorted(inspector.get_table_names()) + return set(inspector.get_table_names()) diff --git a/superset/models/core.py b/superset/models/core.py index 86b9eb1bde759..28ad6702cb4eb 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -546,7 +546,7 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument cache: bool = False, cache_timeout: Optional[int] = None, force: bool = False, - ) -> List[Tuple[str, str]]: + ) -> Set[Tuple[str, str]]: """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in @@ -556,13 +556,17 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument :param cache: whether cache is enabled for the function :param cache_timeout: timeout in seconds for the cache :param force: whether to force refresh the cache - :return: list of tables + :return: set of tables """ try: - tables = self.db_engine_spec.get_table_names( - database=self, inspector=self.inspector, schema=schema - ) - return [(table, schema) for table in tables] + return { + (table, schema) + for table in self.db_engine_spec.get_table_names( + database=self, + inspector=self.inspector, + schema=schema, + ) + } except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) @@ -576,7 +580,7 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument cache: bool = False, cache_timeout: Optional[int] = None, force: bool = False, - ) -> List[Tuple[str, str]]: + ) -> Set[Tuple[str, str]]: """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in @@ -586,13 +590,17 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument :param cache: whether cache is enabled for the function :param cache_timeout: timeout in seconds for the cache :param force: whether to force refresh the cache - :return: list of views + :return: set of views """ try: - views = self.db_engine_spec.get_view_names( - database=self, inspector=self.inspector, schema=schema - ) - return [(view, schema) for view in views] + return { + (view, schema) + for view in self.db_engine_spec.get_view_names( + database=self, + inspector=self.inspector, + schema=schema, + ) + } except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) diff --git a/superset/views/core.py b/superset/views/core.py index cc1865452a856..a87b6830460db 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -1185,7 +1185,7 @@ def tables( # pylint: disable=no-self-use tables = security_manager.get_datasources_accessible_by_user( database=database, schema=schema_parsed, - datasource_names=[ + datasource_names=sorted( utils.DatasourceName(*datasource_name) for datasource_name in database.get_all_table_names_in_schema( schema=schema_parsed, @@ -1193,13 +1193,13 @@ def tables( # pylint: disable=no-self-use cache=database.table_cache_enabled, cache_timeout=database.table_cache_timeout, ) - ], + ), ) views = security_manager.get_datasources_accessible_by_user( database=database, schema=schema_parsed, - datasource_names=[ + datasource_names=sorted( utils.DatasourceName(*datasource_name) for datasource_name in database.get_all_view_names_in_schema( schema=schema_parsed, @@ -1207,7 +1207,7 @@ def tables( # pylint: disable=no-self-use cache=database.table_cache_enabled, cache_timeout=database.table_cache_timeout, ) - ], + ), ) except SupersetException as ex: return json_error_response(ex.message, ex.status) diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 0175a2c3341b3..5c38330dceea0 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -765,7 +765,7 @@ def test_create_dataset_validate_view_exists( with patch.object( dialect, "get_view_names", wraps=dialect.get_view_names ) as patch_get_view_names: - patch_get_view_names.return_value = ["test_case_view"] + patch_get_view_names.return_value = {"test_case_view"} self.login(username="admin") table_data = { diff --git a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py index c31a501487dda..a33b53812ce8a 100644 --- a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py +++ b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py @@ -229,11 +229,11 @@ def test_get_table_names(self): """ Make sure base engine spec removes schema name from table name ie. when try_remove_schema_from_table_name == True. """ - base_result_expected = ["table", "table_2"] + base_result_expected = {"table", "table_2"} base_result = BaseEngineSpec.get_table_names( database=mock.ANY, schema="schema", inspector=inspector ) - self.assertListEqual(base_result_expected, base_result) + self.assertSetEqual(base_result_expected, base_result) @pytest.mark.usefixtures("load_energy_table_with_slice") def test_column_datatype_to_string(self): diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py index 5021075fe7728..b38f72ae96f89 100644 --- a/tests/integration_tests/db_engine_specs/postgres_tests.py +++ b/tests/integration_tests/db_engine_specs/postgres_tests.py @@ -45,11 +45,11 @@ def test_get_table_names(self): inspector.get_table_names = mock.Mock(return_value=["schema.table", "table_2"]) inspector.get_foreign_table_names = mock.Mock(return_value=["table_3"]) - pg_result_expected = ["schema.table", "table_2", "table_3"] + pg_result_expected = {"schema.table", "table_2", "table_3"} pg_result = PostgresEngineSpec.get_table_names( database=mock.ANY, schema="schema", inspector=inspector ) - self.assertListEqual(pg_result_expected, pg_result) + self.assertSetEqual(pg_result_expected, pg_result) def test_time_exp_literal_no_grain(self): """ diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py index d37a04645f8cb..4b82d8b53f2c0 100644 --- a/tests/integration_tests/db_engine_specs/presto_tests.py +++ b/tests/integration_tests/db_engine_specs/presto_tests.py @@ -55,7 +55,7 @@ def test_get_view_names_with_schema(self): ).strip(), {"schema": schema}, ) - assert result == ["a", "d"] + assert result == {"a", "d"} def test_get_view_names_without_schema(self): database = mock.MagicMock() @@ -76,7 +76,7 @@ def test_get_view_names_without_schema(self): ).strip(), {}, ) - assert result == ["a", "d"] + assert result == {"a", "d"} def verify_presto_column(self, column, expected_results): inspector = mock.Mock() @@ -669,10 +669,10 @@ def test_get_table_names( mock_get_view_names, mock_get_table_names, ): - mock_get_view_names.return_value = ["view1", "view2"] - mock_get_table_names.return_value = ["table1", "table2", "view1", "view2"] + mock_get_view_names.return_value = {"view1", "view2"} + mock_get_table_names.return_value = {"table1", "table2", "view1", "view2"} tables = PrestoEngineSpec.get_table_names(mock.Mock(), mock.Mock(), None) - assert tables == ["table1", "table2"] + assert tables == {"table1", "table2"} def test_get_full_name(self): names = [