Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: make catalog migration lenient #29549

Merged
merged 2 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 79 additions & 38 deletions superset/migrations/shared/catalogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import sqlalchemy as sa
from alembic import op
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session

from superset import db, security_manager
from superset.daos.database import DatabaseDAO
Expand Down Expand Up @@ -86,6 +87,24 @@ class Slice(Base):
schema_perm = sa.Column(sa.String(1000))


def get_schemas(database_name: str) -> list[str]:
    """
    Read all known schemas from the schema permissions.

    The schema names are recovered from the names of the view menus that carry
    the ``schema_access`` permission and start with ``[database_name]``.

    :param database_name: name of the database whose schema permissions are read
    :returns: sorted list of distinct schema names
    """
    # The database name is passed as a bound parameter instead of being
    # interpolated into the SQL, so names containing quotes cannot break (or
    # alter) the statement.
    query = sa.text(
        """
        SELECT
            avm.name
        FROM ab_view_menu avm
        JOIN ab_permission_view apv ON avm.id = apv.view_menu_id
        JOIN ab_permission ap ON apv.permission_id = ap.id
        WHERE
            avm.name LIKE :pattern AND
            ap.name = 'schema_access'
        """
    )
    # `op.execute()` does not return a result set; rows have to be fetched
    # through the connection bound to the migration context.
    rows = op.get_bind().execute(query, {"pattern": f"[{database_name}]%"})

    # [PostgreSQL].[postgres].[public] => public
    # Strip the outer brackets first and split on "].[" so that schema names
    # containing a dot are kept intact.
    return sorted({row[0][1:-1].split("].[")[-1] for row in rows})


def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
"""
Update models when catalogs are introduced in a DB engine spec.
Expand Down Expand Up @@ -116,25 +135,7 @@ def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
)
add_pvms(session, {perm: ("catalog_access",)})

# update schema_perms
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
for schema in database.get_all_schema_names(
catalog=catalog,
cache=False,
ssh_tunnel=ssh_tunnel,
):
perm = security_manager.get_schema_perm(
database.database_name,
None,
schema,
)
existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
if existing_pvm:
existing_pvm.name = security_manager.get_schema_perm(
database.database_name,
catalog,
schema,
)
upgrade_schema_perms(database, catalog, session)

# update existing models
models = [
Expand Down Expand Up @@ -166,6 +167,35 @@ def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
session.commit()


def upgrade_schema_perms(database: Database, catalog: str, session: Session) -> None:
    """
    Rename existing schema permissions to include the catalog.

    The list of schemas is read from the database itself. If that fails for any
    reason the schemas are read from the existing schema permissions instead
    (see `get_schemas`), so the migration can proceed.

    :param database: database whose schema permissions are renamed
    :param catalog: catalog to include in the permission names
    :param session: session used to look up and update the view menus
    """
    # Imported locally so this function does not depend on a module-level logger.
    import logging

    ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
    try:
        schemas = database.get_all_schema_names(
            catalog=catalog,
            cache=False,
            ssh_tunnel=ssh_tunnel,
        )
    except Exception:  # pylint: disable=broad-except
        # Deliberately best-effort: record the failure instead of hiding it,
        # then fall back to the schemas already known from the permissions.
        logging.getLogger(__name__).warning(
            "Unable to fetch schemas from database %s, "
            "falling back to existing schema permissions",
            database.database_name,
            exc_info=True,
        )
        schemas = get_schemas(database.database_name)

    for schema in schemas:
        # permission name as built without a catalog
        perm = security_manager.get_schema_perm(
            database.database_name,
            None,
            schema,
        )
        existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
        if existing_pvm:
            # rename in place to the permission name built with the catalog
            existing_pvm.name = security_manager.get_schema_perm(
                database.database_name,
                catalog,
                schema,
            )


def downgrade_catalog_perms(engines: set[str] | None = None) -> None:
"""
Reverse the process of `upgrade_catalog_perms`.
Expand All @@ -183,25 +213,7 @@ def downgrade_catalog_perms(engines: set[str] | None = None) -> None:
if catalog is None:
continue

# update schema_perms
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
for schema in database.get_all_schema_names(
catalog=catalog,
cache=False,
ssh_tunnel=ssh_tunnel,
):
perm = security_manager.get_schema_perm(
database.database_name,
catalog,
schema,
)
existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
if existing_pvm:
existing_pvm.name = security_manager.get_schema_perm(
database.database_name,
None,
schema,
)
downgrade_schema_perms(database, catalog, session)

# update existing models
models = [
Expand Down Expand Up @@ -231,3 +243,32 @@ def downgrade_catalog_perms(engines: set[str] | None = None) -> None:
chart.schema_perm = schema_perm

session.commit()


def downgrade_schema_perms(database: Database, catalog: str, session: Session) -> None:
    """
    Rename existing schema permissions to omit the catalog.

    The list of schemas is read from the database itself. If that fails for any
    reason the schemas are read from the existing schema permissions instead
    (see `get_schemas`), so the migration can proceed.

    :param database: database whose schema permissions are renamed
    :param catalog: catalog currently included in the permission names
    :param session: session used to look up and update the view menus
    """
    # Imported locally so this function does not depend on a module-level logger.
    import logging

    ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
    try:
        schemas = database.get_all_schema_names(
            catalog=catalog,
            cache=False,
            ssh_tunnel=ssh_tunnel,
        )
    except Exception:  # pylint: disable=broad-except
        # Deliberately best-effort: record the failure instead of hiding it,
        # then fall back to the schemas already known from the permissions.
        logging.getLogger(__name__).warning(
            "Unable to fetch schemas from database %s, "
            "falling back to existing schema permissions",
            database.database_name,
            exc_info=True,
        )
        schemas = get_schemas(database.database_name)

    for schema in schemas:
        # permission name as built with the catalog
        perm = security_manager.get_schema_perm(
            database.database_name,
            catalog,
            schema,
        )
        existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
        if existing_pvm:
            # rename in place to the permission name built without a catalog
            existing_pvm.name = security_manager.get_schema_perm(
                database.database_name,
                None,
                schema,
            )
125 changes: 125 additions & 0 deletions tests/unit_tests/migrations/shared/catalogs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,128 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None:
("[my_db].[public]",),
("[my_db].[db]",),
]


def test_upgrade_catalog_perms_graceful(
    mocker: MockerFixture,
    session: Session,
) -> None:
    """
    Test the `upgrade_catalog_perms` function when it fails to connect to the DB.

    During the migration we try to connect to the analytical database to get the list of
    schemas. This should fail gracefully and not raise an exception, since the database
    could be offline, and the permissions can be generated later when the admin enables
    catalog browsing on the database (permissions are always synced on a DB update, see
    `UpdateDatabaseCommand`).
    """
    from superset.connectors.sqla.models import SqlaTable
    from superset.models.core import Database
    from superset.models.slice import Slice
    from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState

    engine = session.get_bind()
    Database.metadata.create_all(engine)

    db = mocker.patch("superset.migrations.shared.catalogs.db")
    db.Session.return_value = session

    # simulate an analytical database that cannot be reached
    mocker.patch.object(
        Database,
        "get_all_schema_names",
        side_effect=Exception("Failed to connect to the database"),
    )
    # `op` is replaced by the session (patched once; a second, plain mock patch
    # would be overridden by this one), so SQL issued through `op` by the
    # migration helpers runs against the test database
    mocker.patch("superset.migrations.shared.catalogs.op", session)

    database = Database(
        database_name="my_db",
        sqlalchemy_uri="postgresql://localhost/db",
    )
    dataset = SqlaTable(
        table_name="my_table",
        database=database,
        catalog=None,
        schema="public",
        schema_perm="[my_db].[public]",
    )
    session.add(dataset)
    session.commit()

    chart = Slice(
        slice_name="my_chart",
        datasource_type="table",
        datasource_id=dataset.id,
    )
    query = Query(
        client_id="foo",
        database=database,
        catalog=None,
        schema="public",
    )
    saved_query = SavedQuery(
        database=database,
        sql="SELECT * FROM public.t",
        catalog=None,
        schema="public",
    )
    tab_state = TabState(
        database=database,
        catalog=None,
        schema="public",
    )
    table_schema = TableSchema(
        database=database,
        catalog=None,
        schema="public",
    )
    session.add_all([chart, query, saved_query, tab_state, table_schema])
    session.commit()

    # before migration
    assert dataset.catalog is None
    assert query.catalog is None
    assert saved_query.catalog is None
    assert tab_state.catalog is None
    assert table_schema.catalog is None
    assert dataset.schema_perm == "[my_db].[public]"
    assert chart.schema_perm == "[my_db].[public]"
    assert session.query(ViewMenu.name).all() == [
        ("[my_db].(id:1)",),
        ("[my_db].[my_table](id:1)",),
        ("[my_db].[public]",),
    ]

    upgrade_catalog_perms()

    # after migration
    assert dataset.catalog == "db"
    assert query.catalog == "db"
    assert saved_query.catalog == "db"
    assert tab_state.catalog == "db"
    assert table_schema.catalog == "db"
    assert dataset.schema_perm == "[my_db].[db].[public]"
    assert chart.schema_perm == "[my_db].[db].[public]"
    assert session.query(ViewMenu.name).all() == [
        ("[my_db].(id:1)",),
        ("[my_db].[my_table](id:1)",),
        ("[my_db].[db].[public]",),
        ("[my_db].[db]",),
    ]

    downgrade_catalog_perms()

    # revert
    assert dataset.catalog is None
    assert query.catalog is None
    assert saved_query.catalog is None
    assert tab_state.catalog is None
    assert table_schema.catalog is None
    assert dataset.schema_perm == "[my_db].[public]"
    assert chart.schema_perm == "[my_db].[public]"
    # the catalog permission created by the upgrade is still present
    assert session.query(ViewMenu.name).all() == [
        ("[my_db].(id:1)",),
        ("[my_db].[my_table](id:1)",),
        ("[my_db].[public]",),
        ("[my_db].[db]",),
    ]
Loading