From f29e1e4c29a46f7d607cfa59adb8bb21d107091c Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Thu, 9 May 2024 17:41:15 -0400 Subject: [PATCH] feat: catalog support for Databricks native (#28394) --- .../src/components/DatabaseSelector/index.tsx | 18 +- superset/db_engine_specs/databricks.py | 65 ++++++- superset/migrations/shared/catalogs.py | 163 +++++++++++++++--- ...58d051681a3b_add_catalog_perm_to_tables.py | 8 +- ...81be5b6b74_enable_catalog_in_databricks.py | 40 +++++ .../pandas_postprocessing/contribution.py | 1 + tests/integration_tests/datasets/api_tests.py | 2 +- .../db_engine_specs/test_databricks.py | 19 ++ .../unit_tests/migrations/shared/__init__.py | 16 ++ .../migrations/shared/catalogs_test.py | 145 ++++++++++++++++ 10 files changed, 442 insertions(+), 35 deletions(-) create mode 100644 superset/migrations/versions/2024-05-08_19-33_4081be5b6b74_enable_catalog_in_databricks.py create mode 100644 tests/unit_tests/migrations/shared/__init__.py create mode 100644 tests/unit_tests/migrations/shared/catalogs_test.py diff --git a/superset-frontend/src/components/DatabaseSelector/index.tsx b/superset-frontend/src/components/DatabaseSelector/index.tsx index 6eb1340d5bcc3..23767ba9f7bff 100644 --- a/superset-frontend/src/components/DatabaseSelector/index.tsx +++ b/superset-frontend/src/components/DatabaseSelector/index.tsx @@ -143,7 +143,7 @@ export default function DatabaseSelector({ const showCatalogSelector = !!db?.allow_multi_catalog; const [currentDb, setCurrentDb] = useState(); const [currentCatalog, setCurrentCatalog] = useState< - CatalogOption | undefined + CatalogOption | null | undefined >(catalog ? { label: catalog, value: catalog, title: catalog } : undefined); const catalogRef = useRef(catalog); catalogRef.current = catalog; @@ -265,7 +265,7 @@ export default function DatabaseSelector({ const schemaOptions = schemaData || EMPTY_SCHEMA_OPTIONS; - function changeCatalog(catalog: CatalogOption | undefined) { + function changeCatalog(catalog: CatalogOption | null | undefined) { setCurrentCatalog(catalog); setCurrentSchema(undefined); if (onCatalogChange && catalog?.value !== catalogRef.current) { @@ -280,7 +280,9 @@ export default function DatabaseSelector({ } = useCatalogs({ dbId: currentDb?.value, onSuccess: (catalogs, isFetched) => { - if (catalogs.length === 1) { + if (!showCatalogSelector) { + changeCatalog(null); + } else if (catalogs.length === 1) { changeCatalog(catalogs[0]); } else if ( !catalogs.find( @@ -290,11 +292,15 @@ export default function DatabaseSelector({ changeCatalog(undefined); } - if (isFetched) { + if (showCatalogSelector && isFetched) { addSuccessToast('List refreshed'); } }, - onError: () => handleError(t('There was an error loading the catalogs')), + onError: () => { + if (showCatalogSelector) { + handleError(t('There was an error loading the catalogs')); + } + }, }); const catalogOptions = catalogData || EMPTY_CATALOG_OPTIONS; @@ -365,7 +371,7 @@ export default function DatabaseSelector({ onChange={item => changeCatalog(item as CatalogOption)} options={catalogOptions} showSearch - value={currentCatalog} + value={currentCatalog || undefined} />, refreshIcon, ); diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py index 6fc753c00e710..3f72931626893 100644 --- a/superset/db_engine_specs/databricks.py +++ b/superset/db_engine_specs/databricks.py @@ -39,7 +39,6 @@ from superset.models.core import Database -# class DatabricksBaseSchema(Schema): """ Fields that are required for both Databricks drivers that 
uses a @@ -371,6 +370,8 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec): "extra", } + supports_dynamic_schema = supports_catalog = supports_dynamic_catalog = True + @classmethod def build_sqlalchemy_uri( # type: ignore cls, parameters: DatabricksNativeParametersType, *_ @@ -428,6 +429,35 @@ def parameters_json_schema(cls) -> Any: spec.components.schema(cls.__name__, schema=cls.properties_schema) return spec.to_dict()["components"]["schemas"][cls.__name__] + @classmethod + def get_default_catalog( + cls, + database: Database, + ) -> str | None: + with database.get_inspector() as inspector: + return inspector.bind.execute("SELECT current_catalog()").scalar() + + @classmethod + def get_prequeries( + cls, + catalog: str | None = None, + schema: str | None = None, + ) -> list[str]: + prequeries = [] + if catalog: + prequeries.append(f"USE CATALOG {catalog}") + if schema: + prequeries.append(f"USE SCHEMA {schema}") + return prequeries + + @classmethod + def get_catalog_names( + cls, + database: Database, + inspector: Inspector, + ) -> set[str]: + return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")} + class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec): engine = "databricks" @@ -455,6 +485,8 @@ class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec): "http_path_field", } + supports_dynamic_schema = supports_catalog = supports_dynamic_catalog = True + @classmethod def build_sqlalchemy_uri( # type: ignore cls, parameters: DatabricksPythonConnectorParametersType, *_ @@ -502,3 +534,34 @@ def get_parameters_from_uri( # type: ignore "default_schema": query["schema"], "encryption": encryption, } + + @classmethod + def get_default_catalog( + cls, + database: Database, + ) -> str | None: + return database.url_object.query.get("catalog") + + @classmethod + def get_catalog_names( + cls, + database: Database, + inspector: Inspector, + ) -> set[str]: + return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")} + + @classmethod + def adjust_engine_params( + cls, + uri: URL, + connect_args: dict[str, Any], + catalog: str | None = None, + schema: str | None = None, + ) -> tuple[URL, dict[str, Any]]: + if catalog: + uri = uri.update_query_dict({"catalog": catalog}) + + if schema: + uri = uri.update_query_dict({"schema": schema}) + + return uri, connect_args diff --git a/superset/migrations/shared/catalogs.py b/superset/migrations/shared/catalogs.py index 5d01ecfbfbcc5..4b13d6043539a 100644 --- a/superset/migrations/shared/catalogs.py +++ b/superset/migrations/shared/catalogs.py @@ -18,23 +18,84 @@ from __future__ import annotations import logging +from typing import Any, Type +import sqlalchemy as sa from alembic import op +from sqlalchemy.ext.declarative import declarative_base from superset import db, security_manager from superset.daos.database import DatabaseDAO +from superset.migrations.shared.security_converge import add_pvms, ViewMenu from superset.models.core import Database logger = logging.getLogger(__name__) -def upgrade_schema_perms(engine: str | None = None) -> None: +Base: Type[Any] = declarative_base() + + +class SqlaTable(Base): + __tablename__ = "tables" + + id = sa.Column(sa.Integer, primary_key=True) + database_id = sa.Column(sa.Integer, nullable=False) + schema_perm = sa.Column(sa.String(1000)) + schema = sa.Column(sa.String(255)) + catalog = sa.Column(sa.String(256), nullable=True, default=None) + + +class Query(Base): + __tablename__ = "query" + + id = sa.Column(sa.Integer, primary_key=True) 
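+    # Like `SqlaTable` above, this is a minimal declarative copy of the
+    # application model: the migration declares only the columns it touches,
+    # so it does not depend on model definitions that may have drifted ahead
+    # of this revision.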
+ database_id = sa.Column(sa.Integer, nullable=False) + catalog = sa.Column(sa.String(256), nullable=True, default=None) + + +class SavedQuery(Base): + __tablename__ = "saved_query" + + id = sa.Column(sa.Integer, primary_key=True) + db_id = sa.Column(sa.Integer, nullable=False) + catalog = sa.Column(sa.String(256), nullable=True, default=None) + + +class TabState(Base): + __tablename__ = "tab_state" + + id = sa.Column(sa.Integer, primary_key=True) + database_id = sa.Column(sa.Integer, nullable=False) + catalog = sa.Column(sa.String(256), nullable=True, default=None) + + +class TableSchema(Base): + __tablename__ = "table_schema" + + id = sa.Column(sa.Integer, primary_key=True) + database_id = sa.Column(sa.Integer, nullable=False) + catalog = sa.Column(sa.String(256), nullable=True, default=None) + + +class Slice(Base): + __tablename__ = "slices" + + id = sa.Column(sa.Integer, primary_key=True) + datasource_id = sa.Column(sa.Integer) + datasource_type = sa.Column(sa.String(200)) + schema_perm = sa.Column(sa.String(1000)) + + +def upgrade_catalog_perms(engine: str | None = None) -> None: """ - Update schema permissions to include the catalog part. + Update models when catalogs are introduced in a DB engine spec. + + When an existing DB engine spec starts to support catalogs we need to: + + - Add a `catalog_access` permission for each catalog. + - Populate the `catalog` field with the default catalog for each related model. + - Update `schema_perm` to include the default catalog. - Before SIP-95 schema permissions were stored in the format `[db].[schema]`. With the - introduction of catalogs, any existing permissions need to be renamed to include the - catalog: `[db].[catalog].[schema]`. """ bind = op.get_bind() session = db.Session(bind=bind) @@ -46,6 +107,16 @@ def upgrade_schema_perms(engine: str | None = None) -> None: continue catalog = database.get_default_catalog() + if catalog is None: + continue + + perm = security_manager.get_catalog_perm( + database.database_name, + catalog, + ) + add_pvms(session, {perm: ("catalog_access",)}) + + # update schema_perms ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id) for schema in database.get_all_schema_names( catalog=catalog, @@ -57,29 +128,47 @@ def upgrade_schema_perms(engine: str | None = None) -> None: None, schema, ) - existing_pvm = security_manager.find_permission_view_menu( - "schema_access", - perm, - ) + existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none() if existing_pvm: - existing_pvm.view_menu.name = security_manager.get_schema_perm( + existing_pvm.name = security_manager.get_schema_perm( database.database_name, catalog, schema, ) + # update existing models + models = [ + (Query, "database_id"), + (SavedQuery, "db_id"), + (TabState, "database_id"), + (TableSchema, "database_id"), + (SqlaTable, "database_id"), + ] + for model, column in models: + for instance in session.query(model).filter( + getattr(model, column) == database.id + ): + instance.catalog = catalog + + for table in session.query(SqlaTable).filter_by(database_id=database.id): + schema_perm = security_manager.get_schema_perm( + database.database_name, + catalog, + table.schema, + ) + table.schema_perm = schema_perm + for chart in session.query(Slice).filter_by( + datasource_id=table.id, + datasource_type="table", + ): + chart.schema_perm = schema_perm + session.commit() -def downgrade_schema_perms(engine: str | None = None) -> None: +def downgrade_catalog_perms(engine: str | None = None) -> None: """ - Update schema permissions to not have the 
catalog part. - - Before SIP-95 schema permissions were stored in the format `[db].[schema]`. With the - introduction of catalogs, any existing permissions need to be renamed to include the - catalog: `[db].[catalog].[schema]`. - - This helped function reverts the process. + Reverse the process of `upgrade_catalog_perms`. """ bind = op.get_bind() session = db.Session(bind=bind) @@ -91,6 +180,10 @@ def downgrade_schema_perms(engine: str | None = None) -> None: continue catalog = database.get_default_catalog() + if catalog is None: + continue + + # update schema_perms ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id) for schema in database.get_all_schema_names( catalog=catalog, @@ -102,15 +195,39 @@ def downgrade_schema_perms(engine: str | None = None) -> None: catalog, schema, ) - existing_pvm = security_manager.find_permission_view_menu( - "schema_access", - perm, - ) + existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none() if existing_pvm: - existing_pvm.view_menu.name = security_manager.get_schema_perm( + existing_pvm.name = security_manager.get_schema_perm( database.database_name, None, schema, ) + # update existing models + models = [ + (Query, "database_id"), + (SavedQuery, "db_id"), + (TabState, "database_id"), + (TableSchema, "database_id"), + (SqlaTable, "database_id"), + ] + for model, column in models: + for instance in session.query(model).filter( + getattr(model, column) == database.id + ): + instance.catalog = None + + for table in session.query(SqlaTable).filter_by(database_id=database.id): + schema_perm = security_manager.get_schema_perm( + database.database_name, + None, + table.schema, + ) + table.schema_perm = schema_perm + for chart in session.query(Slice).filter_by( + datasource_id=table.id, + datasource_type="table", + ): + chart.schema_perm = schema_perm + session.commit() diff --git a/superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py b/superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py index 17b33e1d0a8f8..f8f782474482a 100644 --- a/superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py +++ b/superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py @@ -26,8 +26,8 @@ from alembic import op from superset.migrations.shared.catalogs import ( - downgrade_schema_perms, - upgrade_schema_perms, + downgrade_catalog_perms, + upgrade_catalog_perms, ) # revision identifiers, used by Alembic. @@ -44,10 +44,10 @@ def upgrade(): "slices", sa.Column("catalog_perm", sa.String(length=1000), nullable=True), ) - upgrade_schema_perms(engine="postgresql") + upgrade_catalog_perms(engine="postgresql") def downgrade(): op.drop_column("slices", "catalog_perm") op.drop_column("tables", "catalog_perm") - downgrade_schema_perms(engine="postgresql") + downgrade_catalog_perms(engine="postgresql") diff --git a/superset/migrations/versions/2024-05-08_19-33_4081be5b6b74_enable_catalog_in_databricks.py b/superset/migrations/versions/2024-05-08_19-33_4081be5b6b74_enable_catalog_in_databricks.py new file mode 100644 index 0000000000000..f39d6fa0d6175 --- /dev/null +++ b/superset/migrations/versions/2024-05-08_19-33_4081be5b6b74_enable_catalog_in_databricks.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Enable catalog in Databricks + +Revision ID: 4081be5b6b74 +Revises: 645bb206f96c +Create Date: 2024-05-08 19:33:18.311411 + +""" + +from superset.migrations.shared.catalogs import ( + downgrade_catalog_perms, + upgrade_catalog_perms, +) + +# revision identifiers, used by Alembic. +revision = "4081be5b6b74" +down_revision = "645bb206f96c" + + +def upgrade(): + upgrade_catalog_perms(engine="databricks") + + +def downgrade(): + downgrade_catalog_perms(engine="databricks") diff --git a/superset/utils/pandas_postprocessing/contribution.py b/superset/utils/pandas_postprocessing/contribution.py index 46144ec019402..ad8b070869cf4 100644 --- a/superset/utils/pandas_postprocessing/contribution.py +++ b/superset/utils/pandas_postprocessing/contribution.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + from __future__ import annotations from decimal import Decimal diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 543a834793bdd..6cc3cc8828dc0 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -368,8 +368,8 @@ def test_get_dataset_item(self): expected_result = { "cache_timeout": None, "database": { - "backend": main_db.backend, "allow_multi_catalog": False, + "backend": main_db.backend, "database_name": "examples", "id": 1, }, diff --git a/tests/unit_tests/db_engine_specs/test_databricks.py b/tests/unit_tests/db_engine_specs/test_databricks.py index 8709833d3f444..204faed445356 100644 --- a/tests/unit_tests/db_engine_specs/test_databricks.py +++ b/tests/unit_tests/db_engine_specs/test_databricks.py @@ -245,3 +245,22 @@ def test_convert_dttm( from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec as spec assert_convert_dttm(spec, target_type, expected_result, dttm) + + +def test_get_prequeries() -> None: + """ + Test the ``get_prequeries`` method. + """ + from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec + + assert DatabricksNativeEngineSpec.get_prequeries() == [] + assert DatabricksNativeEngineSpec.get_prequeries(schema="test") == [ + "USE SCHEMA test", + ] + assert DatabricksNativeEngineSpec.get_prequeries(catalog="test") == [ + "USE CATALOG test", + ] + assert DatabricksNativeEngineSpec.get_prequeries(catalog="foo", schema="bar") == [ + "USE CATALOG foo", + "USE SCHEMA bar", + ] diff --git a/tests/unit_tests/migrations/shared/__init__.py b/tests/unit_tests/migrations/shared/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/unit_tests/migrations/shared/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/migrations/shared/catalogs_test.py b/tests/unit_tests/migrations/shared/catalogs_test.py new file mode 100644 index 0000000000000..ca715bec94151 --- /dev/null +++ b/tests/unit_tests/migrations/shared/catalogs_test.py @@ -0,0 +1,145 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pytest_mock import MockerFixture +from sqlalchemy.orm.session import Session + +from superset.migrations.shared.catalogs import ( + downgrade_catalog_perms, + upgrade_catalog_perms, +) +from superset.migrations.shared.security_converge import ViewMenu + + +def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None: + """ + Test the `upgrade_catalog_perms` function. + + The function is called when catalogs are introduced into a new DB engine spec. When + that happens, we need to update the `catalog` attribute so it points to the default + catalog, instead of being `NULL`. We also need to update `schema_perms` to include + the default catalog. 
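+
+    For example, with a default catalog of `db`, a dataset whose `schema_perm`
+    is `[my_db].[public]` should end up with `[my_db].[db].[public]`, and a
+    new `catalog_access` permission should be created for `[my_db].[db]`.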
+ """ + from superset.connectors.sqla.models import SqlaTable + from superset.models.core import Database + from superset.models.slice import Slice + from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState + + engine = session.get_bind() + Database.metadata.create_all(engine) + + mocker.patch("superset.migrations.shared.catalogs.op") + db = mocker.patch("superset.migrations.shared.catalogs.db") + db.Session.return_value = session + + mocker.patch.object( + Database, + "get_all_schema_names", + return_value=["public", "information_schema"], + ) + + database = Database( + database_name="my_db", + sqlalchemy_uri="postgresql://localhost/db", + ) + dataset = SqlaTable( + table_name="my_table", + database=database, + catalog=None, + schema="public", + schema_perm="[my_db].[public]", + ) + session.add(dataset) + session.commit() + + chart = Slice( + slice_name="my_chart", + datasource_type="table", + datasource_id=dataset.id, + ) + query = Query( + client_id="foo", + database=database, + catalog=None, + schema="public", + ) + saved_query = SavedQuery( + database=database, + sql="SELECT * FROM public.t", + catalog=None, + schema="public", + ) + tab_state = TabState( + database=database, + catalog=None, + schema="public", + ) + table_schema = TableSchema( + database=database, + catalog=None, + schema="public", + ) + session.add_all([chart, query, saved_query, tab_state, table_schema]) + session.commit() + + # before migration + assert dataset.catalog is None + assert query.catalog is None + assert saved_query.catalog is None + assert tab_state.catalog is None + assert table_schema.catalog is None + assert dataset.schema_perm == "[my_db].[public]" + assert chart.schema_perm == "[my_db].[public]" + assert session.query(ViewMenu.name).all() == [ + ("[my_db].(id:1)",), + ("[my_db].[my_table](id:1)",), + ("[my_db].[public]",), + ] + + upgrade_catalog_perms() + + # after migration + assert dataset.catalog == "db" + assert query.catalog == "db" + assert saved_query.catalog == "db" + assert tab_state.catalog == "db" + assert table_schema.catalog == "db" + assert dataset.schema_perm == "[my_db].[db].[public]" + assert chart.schema_perm == "[my_db].[db].[public]" + assert session.query(ViewMenu.name).all() == [ + ("[my_db].(id:1)",), + ("[my_db].[my_table](id:1)",), + ("[my_db].[db].[public]",), + ("[my_db].[db]",), + ] + + downgrade_catalog_perms() + + # revert + assert dataset.catalog is None + assert query.catalog is None + assert saved_query.catalog is None + assert tab_state.catalog is None + assert table_schema.catalog is None + assert dataset.schema_perm == "[my_db].[public]" + assert chart.schema_perm == "[my_db].[public]" + assert session.query(ViewMenu.name).all() == [ + ("[my_db].(id:1)",), + ("[my_db].[my_table](id:1)",), + ("[my_db].[public]",), + ("[my_db].[db]",), + ]