Skip to content

Commit

Permalink
fix: remove unneeded complexity in migration (#19022)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored Mar 4, 2022
1 parent 77063cc commit 50bb86d
Showing 1 changed file with 7 additions and 89 deletions.
96 changes: 7 additions & 89 deletions superset/migrations/versions/b8d3a24d9131_new_dataset_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,22 @@
"""

import json
from typing import Any, Dict, List, Optional, Type
from typing import List
from uuid import uuid4

import sqlalchemy as sa
from alembic import op
from sqlalchemy import and_, inspect, or_
from sqlalchemy.engine import create_engine, Engine
from sqlalchemy.engine.url import make_url, URL
from sqlalchemy.exc import ArgumentError
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import backref, relationship, Session
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy_utils import UUIDType

from superset import app, db, db_engine_specs
from superset import app, db
from superset.connectors.sqla.models import ADDITIVE_METRIC_TYPES
from superset.extensions import encrypted_field_factory, security_manager
from superset.extensions import encrypted_field_factory
from superset.sql_parse import ParsedQuery
from superset.utils.memoized import memoized

# revision identifiers, used by Alembic.
revision = "b8d3a24d9131"
Expand Down Expand Up @@ -78,86 +75,6 @@ class Database(Base):
)
server_cert = sa.Column(encrypted_field_factory.create(sa.Text), nullable=True)

@property
def sqlalchemy_uri_decrypted(self) -> str:
try:
url = make_url(self.sqlalchemy_uri)
except (ArgumentError, ValueError):
return "dialect://invalid_uri"
if custom_password_store:
url.password = custom_password_store(url)
else:
url.password = self.password
return str(url)

@property
def backend(self) -> str:
sqlalchemy_url = make_url(self.sqlalchemy_uri_decrypted)
return sqlalchemy_url.get_backend_name() # pylint: disable=no-member

@classmethod
@memoized
def get_db_engine_spec_for_backend(
cls, backend: str
) -> Type[db_engine_specs.BaseEngineSpec]:
engines = db_engine_specs.get_engine_specs()
return engines.get(backend, db_engine_specs.BaseEngineSpec)

@property
def db_engine_spec(self) -> Type[db_engine_specs.BaseEngineSpec]:
return self.get_db_engine_spec_for_backend(self.backend)

def get_extra(self) -> Dict[str, Any]:
return self.db_engine_spec.get_extra_params(self)

def get_effective_user(
self, object_url: URL, user_name: Optional[str] = None,
) -> Optional[str]:
effective_username = None
if self.impersonate_user:
effective_username = object_url.username
if user_name:
effective_username = user_name

return effective_username

def get_encrypted_extra(self) -> Dict[str, Any]:
return json.loads(self.encrypted_extra) if self.encrypted_extra else {}

@memoized(watch=("impersonate_user", "sqlalchemy_uri_decrypted", "extra"))
def get_sqla_engine(self, schema: Optional[str] = None) -> Engine:
extra = self.get_extra()
sqlalchemy_url = make_url(self.sqlalchemy_uri_decrypted)
self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
effective_username = self.get_effective_user(sqlalchemy_url, "admin")
# If using MySQL or Presto for example, will set url.username
self.db_engine_spec.modify_url_for_impersonation(
sqlalchemy_url, self.impersonate_user, effective_username
)

params = extra.get("engine_params", {})
connect_args = params.get("connect_args", {})
if self.impersonate_user:
self.db_engine_spec.update_impersonation_config(
connect_args, str(sqlalchemy_url), effective_username
)

if connect_args:
params["connect_args"] = connect_args

params.update(self.get_encrypted_extra())

if DB_CONNECTION_MUTATOR:
sqlalchemy_url, params = DB_CONNECTION_MUTATOR(
sqlalchemy_url,
params,
effective_username,
security_manager,
"migration",
)

return create_engine(sqlalchemy_url, **params)


class TableColumn(Base):

Expand Down Expand Up @@ -325,8 +242,9 @@ def after_insert(target: SqlaTable) -> None: # pylint: disable=too-many-locals
)
if not database:
return
engine = database.get_sqla_engine(schema=target.schema)
conditional_quote = engine.dialect.identifier_preparer.quote
url = make_url(database.sqlalchemy_uri)
dialect_class = url.get_dialect()
conditional_quote = dialect_class().identifier_preparer.quote

# create columns
columns = []
Expand Down

0 comments on commit 50bb86d

Please sign in to comment.