Skip to content

Commit

Permalink
feat: catalog support for Databricks native (apache#28394)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored May 9, 2024
1 parent e516bba commit f29e1e4
Show file tree
Hide file tree
Showing 10 changed files with 442 additions and 35 deletions.
18 changes: 12 additions & 6 deletions superset-frontend/src/components/DatabaseSelector/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ export default function DatabaseSelector({
const showCatalogSelector = !!db?.allow_multi_catalog;
const [currentDb, setCurrentDb] = useState<DatabaseValue | undefined>();
const [currentCatalog, setCurrentCatalog] = useState<
CatalogOption | undefined
CatalogOption | null | undefined
>(catalog ? { label: catalog, value: catalog, title: catalog } : undefined);
const catalogRef = useRef(catalog);
catalogRef.current = catalog;
Expand Down Expand Up @@ -265,7 +265,7 @@ export default function DatabaseSelector({

const schemaOptions = schemaData || EMPTY_SCHEMA_OPTIONS;

function changeCatalog(catalog: CatalogOption | undefined) {
function changeCatalog(catalog: CatalogOption | null | undefined) {
setCurrentCatalog(catalog);
setCurrentSchema(undefined);
if (onCatalogChange && catalog?.value !== catalogRef.current) {
Expand All @@ -280,7 +280,9 @@ export default function DatabaseSelector({
} = useCatalogs({
dbId: currentDb?.value,
onSuccess: (catalogs, isFetched) => {
if (catalogs.length === 1) {
if (!showCatalogSelector) {
changeCatalog(null);
} else if (catalogs.length === 1) {
changeCatalog(catalogs[0]);
} else if (
!catalogs.find(
Expand All @@ -290,11 +292,15 @@ export default function DatabaseSelector({
changeCatalog(undefined);
}

if (isFetched) {
if (showCatalogSelector && isFetched) {
addSuccessToast('List refreshed');
}
},
onError: () => handleError(t('There was an error loading the catalogs')),
onError: () => {
if (showCatalogSelector) {
handleError(t('There was an error loading the catalogs'));
}
},
});

const catalogOptions = catalogData || EMPTY_CATALOG_OPTIONS;
Expand Down Expand Up @@ -365,7 +371,7 @@ export default function DatabaseSelector({
onChange={item => changeCatalog(item as CatalogOption)}
options={catalogOptions}
showSearch
value={currentCatalog}
value={currentCatalog || undefined}
/>,
refreshIcon,
);
Expand Down
65 changes: 64 additions & 1 deletion superset/db_engine_specs/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from superset.models.core import Database


#
class DatabricksBaseSchema(Schema):
"""
Fields that are required for both Databricks drivers that uses a
Expand Down Expand Up @@ -371,6 +370,8 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
"extra",
}

supports_dynamic_schema = supports_catalog = supports_dynamic_catalog = True

@classmethod
def build_sqlalchemy_uri( # type: ignore
cls, parameters: DatabricksNativeParametersType, *_
Expand Down Expand Up @@ -428,6 +429,35 @@ def parameters_json_schema(cls) -> Any:
spec.components.schema(cls.__name__, schema=cls.properties_schema)
return spec.to_dict()["components"]["schemas"][cls.__name__]

@classmethod
def get_default_catalog(
cls,
database: Database,
) -> str | None:
with database.get_inspector() as inspector:
return inspector.bind.execute("SELECT current_catalog()").scalar()

@classmethod
def get_prequeries(
cls,
catalog: str | None = None,
schema: str | None = None,
) -> list[str]:
prequeries = []
if catalog:
prequeries.append(f"USE CATALOG {catalog}")
if schema:
prequeries.append(f"USE SCHEMA {schema}")
return prequeries

@classmethod
def get_catalog_names(
cls,
database: Database,
inspector: Inspector,
) -> set[str]:
return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")}


class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
engine = "databricks"
Expand Down Expand Up @@ -455,6 +485,8 @@ class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec):
"http_path_field",
}

supports_dynamic_schema = supports_catalog = supports_dynamic_catalog = True

@classmethod
def build_sqlalchemy_uri( # type: ignore
cls, parameters: DatabricksPythonConnectorParametersType, *_
Expand Down Expand Up @@ -502,3 +534,34 @@ def get_parameters_from_uri( # type: ignore
"default_schema": query["schema"],
"encryption": encryption,
}

@classmethod
def get_default_catalog(
cls,
database: Database,
) -> str | None:
return database.url_object.query.get("catalog")

@classmethod
def get_catalog_names(
cls,
database: Database,
inspector: Inspector,
) -> set[str]:
return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")}

@classmethod
def adjust_engine_params(
cls,
uri: URL,
connect_args: dict[str, Any],
catalog: str | None = None,
schema: str | None = None,
) -> tuple[URL, dict[str, Any]]:
if catalog:
uri = uri.update_query_dict({"catalog": catalog})

if schema:
uri = uri.update_query_dict({"schema": schema})

return uri, connect_args
163 changes: 140 additions & 23 deletions superset/migrations/shared/catalogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,84 @@
from __future__ import annotations

import logging
from typing import Any, Type

import sqlalchemy as sa
from alembic import op
from sqlalchemy.ext.declarative import declarative_base

from superset import db, security_manager
from superset.daos.database import DatabaseDAO
from superset.migrations.shared.security_converge import add_pvms, ViewMenu
from superset.models.core import Database

logger = logging.getLogger(__name__)


def upgrade_schema_perms(engine: str | None = None) -> None:
Base: Type[Any] = declarative_base()


class SqlaTable(Base):
__tablename__ = "tables"

id = sa.Column(sa.Integer, primary_key=True)
database_id = sa.Column(sa.Integer, nullable=False)
schema_perm = sa.Column(sa.String(1000))
schema = sa.Column(sa.String(255))
catalog = sa.Column(sa.String(256), nullable=True, default=None)


class Query(Base):
__tablename__ = "query"

id = sa.Column(sa.Integer, primary_key=True)
database_id = sa.Column(sa.Integer, nullable=False)
catalog = sa.Column(sa.String(256), nullable=True, default=None)


class SavedQuery(Base):
__tablename__ = "saved_query"

id = sa.Column(sa.Integer, primary_key=True)
db_id = sa.Column(sa.Integer, nullable=False)
catalog = sa.Column(sa.String(256), nullable=True, default=None)


class TabState(Base):
__tablename__ = "tab_state"

id = sa.Column(sa.Integer, primary_key=True)
database_id = sa.Column(sa.Integer, nullable=False)
catalog = sa.Column(sa.String(256), nullable=True, default=None)


class TableSchema(Base):
__tablename__ = "table_schema"

id = sa.Column(sa.Integer, primary_key=True)
database_id = sa.Column(sa.Integer, nullable=False)
catalog = sa.Column(sa.String(256), nullable=True, default=None)


class Slice(Base):
__tablename__ = "slices"

id = sa.Column(sa.Integer, primary_key=True)
datasource_id = sa.Column(sa.Integer)
datasource_type = sa.Column(sa.String(200))
schema_perm = sa.Column(sa.String(1000))


def upgrade_catalog_perms(engine: str | None = None) -> None:
"""
Update schema permissions to include the catalog part.
Update models when catalogs are introduced in a DB engine spec.
When an existing DB engine spec starts to support catalogs we need to:
- Add a `catalog_access` permission for each catalog.
- Populate the `catalog` field with the default catalog for each related model.
- Update `schema_perm` to include the default catalog.
Before SIP-95 schema permissions were stored in the format `[db].[schema]`. With the
introduction of catalogs, any existing permissions need to be renamed to include the
catalog: `[db].[catalog].[schema]`.
"""
bind = op.get_bind()
session = db.Session(bind=bind)
Expand All @@ -46,6 +107,16 @@ def upgrade_schema_perms(engine: str | None = None) -> None:
continue

catalog = database.get_default_catalog()
if catalog is None:
continue

perm = security_manager.get_catalog_perm(
database.database_name,
catalog,
)
add_pvms(session, {perm: ("catalog_access",)})

# update schema_perms
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
for schema in database.get_all_schema_names(
catalog=catalog,
Expand All @@ -57,29 +128,47 @@ def upgrade_schema_perms(engine: str | None = None) -> None:
None,
schema,
)
existing_pvm = security_manager.find_permission_view_menu(
"schema_access",
perm,
)
existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
if existing_pvm:
existing_pvm.view_menu.name = security_manager.get_schema_perm(
existing_pvm.name = security_manager.get_schema_perm(
database.database_name,
catalog,
schema,
)

# update existing models
models = [
(Query, "database_id"),
(SavedQuery, "db_id"),
(TabState, "database_id"),
(TableSchema, "database_id"),
(SqlaTable, "database_id"),
]
for model, column in models:
for instance in session.query(model).filter(
getattr(model, column) == database.id
):
instance.catalog = catalog

for table in session.query(SqlaTable).filter_by(database_id=database.id):
schema_perm = security_manager.get_schema_perm(
database.database_name,
catalog,
table.schema,
)
table.schema_perm = schema_perm
for chart in session.query(Slice).filter_by(
datasource_id=table.id,
datasource_type="table",
):
chart.schema_perm = schema_perm

session.commit()


def downgrade_schema_perms(engine: str | None = None) -> None:
def downgrade_catalog_perms(engine: str | None = None) -> None:
"""
Update schema permissions to not have the catalog part.
Before SIP-95 schema permissions were stored in the format `[db].[schema]`. With the
introduction of catalogs, any existing permissions need to be renamed to include the
catalog: `[db].[catalog].[schema]`.
This helped function reverts the process.
Reverse the process of `upgrade_catalog_perms`.
"""
bind = op.get_bind()
session = db.Session(bind=bind)
Expand All @@ -91,6 +180,10 @@ def downgrade_schema_perms(engine: str | None = None) -> None:
continue

catalog = database.get_default_catalog()
if catalog is None:
continue

# update schema_perms
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
for schema in database.get_all_schema_names(
catalog=catalog,
Expand All @@ -102,15 +195,39 @@ def downgrade_schema_perms(engine: str | None = None) -> None:
catalog,
schema,
)
existing_pvm = security_manager.find_permission_view_menu(
"schema_access",
perm,
)
existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
if existing_pvm:
existing_pvm.view_menu.name = security_manager.get_schema_perm(
existing_pvm.name = security_manager.get_schema_perm(
database.database_name,
None,
schema,
)

# update existing models
models = [
(Query, "database_id"),
(SavedQuery, "db_id"),
(TabState, "database_id"),
(TableSchema, "database_id"),
(SqlaTable, "database_id"),
]
for model, column in models:
for instance in session.query(model).filter(
getattr(model, column) == database.id
):
instance.catalog = None

for table in session.query(SqlaTable).filter_by(database_id=database.id):
schema_perm = security_manager.get_schema_perm(
database.database_name,
None,
table.schema,
)
table.schema_perm = schema_perm
for chart in session.query(Slice).filter_by(
datasource_id=table.id,
datasource_type="table",
):
chart.schema_perm = schema_perm

session.commit()
Loading

0 comments on commit f29e1e4

Please sign in to comment.