Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: write new dataset on update table if it doesn't exist #19269

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
272 changes: 150 additions & 122 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1863,11 +1863,20 @@ def update_table( # pylint: disable=unused-argument

session.execute(update(SqlaTable).where(SqlaTable.id == target.table.id))

# update ``Column`` model as well
dataset = (
session.query(NewDataset).filter_by(sqlatable_id=target.table.id).one()
session.query(NewDataset)
.filter_by(sqlatable_id=target.table.id)
.one_or_none()
)

if not dataset:
# if dataset is not found create a new copy
# of the dataset instead of updating the existing

SqlaTable.write_shadow_dataset(target.table, database, session)
return

# update ``Column`` model as well
if isinstance(target, TableColumn):
columns = [
column
Expand Down Expand Up @@ -1923,7 +1932,7 @@ def update_table( # pylint: disable=unused-argument
column.extra_json = json.dumps(extra_json) if extra_json else None

@staticmethod
def after_insert( # pylint: disable=too-many-locals
def after_insert(
mapper: Mapper, connection: Connection, target: "SqlaTable",
) -> None:
"""
Expand All @@ -1938,135 +1947,18 @@ def after_insert( # pylint: disable=too-many-locals

For more context: https://github.com/apache/superset/issues/14909
"""
session = inspect(target).session
# set permissions
security_manager.set_perm(mapper, connection, target)

session = inspect(target).session

# get DB-specific conditional quoter for expressions that point to columns or
# table names
database = (
target.database
or session.query(Database).filter_by(id=target.database_id).one()
)
engine = database.get_sqla_engine(schema=target.schema)
conditional_quote = engine.dialect.identifier_preparer.quote

# create columns
columns = []
for column in target.columns:
# ``is_active`` might be ``None`` at this point, but it defaults to ``True``.
if column.is_active is False:
continue

extra_json = json.loads(column.extra or "{}")
for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}:
value = getattr(column, attr)
if value:
extra_json[attr] = value

columns.append(
NewColumn(
name=column.column_name,
type=column.type or "Unknown",
expression=column.expression
or conditional_quote(column.column_name),
description=column.description,
is_temporal=column.is_dttm,
is_aggregation=False,
is_physical=column.expression is None,
is_spatial=False,
is_partition=False,
is_increase_desired=True,
extra_json=json.dumps(extra_json) if extra_json else None,
is_managed_externally=target.is_managed_externally,
external_url=target.external_url,
),
)

# create metrics
for metric in target.metrics:
extra_json = json.loads(metric.extra or "{}")
for attr in {"verbose_name", "metric_type", "d3format"}:
value = getattr(metric, attr)
if value:
extra_json[attr] = value

is_additive = (
metric.metric_type
and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES
)

columns.append(
NewColumn(
name=metric.metric_name,
type="Unknown", # figuring this out would require a type inferrer
expression=metric.expression,
warning_text=metric.warning_text,
description=metric.description,
is_aggregation=True,
is_additive=is_additive,
is_physical=False,
is_spatial=False,
is_partition=False,
is_increase_desired=True,
extra_json=json.dumps(extra_json) if extra_json else None,
is_managed_externally=target.is_managed_externally,
external_url=target.external_url,
),
)

# physical dataset
tables = []
if target.sql is None:
physical_columns = [column for column in columns if column.is_physical]

# create table
table = NewTable(
name=target.table_name,
schema=target.schema,
catalog=None, # currently not supported
database_id=target.database_id,
columns=physical_columns,
is_managed_externally=target.is_managed_externally,
external_url=target.external_url,
)
tables.append(table)

# virtual dataset
else:
# mark all columns as virtual (not physical)
for column in columns:
column.is_physical = False

# find referenced tables
parsed = ParsedQuery(target.sql)
referenced_tables = parsed.tables

# predicate for finding the referenced tables
predicate = or_(
*[
and_(
NewTable.schema == (table.schema or target.schema),
NewTable.name == table.table,
)
for table in referenced_tables
]
)
tables = session.query(NewTable).filter(predicate).all()

# create the new dataset
dataset = NewDataset(
sqlatable_id=target.id,
name=target.table_name,
expression=target.sql or conditional_quote(target.table_name),
tables=tables,
columns=columns,
is_physical=target.sql is None,
is_managed_externally=target.is_managed_externally,
external_url=target.external_url,
)
session.add(dataset)
SqlaTable.write_shadow_dataset(target, database, session)

@staticmethod
def after_delete( # pylint: disable=unused-argument
Expand Down Expand Up @@ -2301,6 +2193,142 @@ def after_update( # pylint: disable=too-many-branches, too-many-locals, too-man
dataset.expression = target.sql or conditional_quote(target.table_name)
dataset.is_physical = target.sql is None

@staticmethod
def write_shadow_dataset( # pylint: disable=too-many-locals
dataset: "SqlaTable", database: Database, session: Session
) -> None:
"""
Shadow write the dataset to new models.

The ``SqlaTable`` model is currently being migrated to two new models, ``Table``
and ``Dataset``. In the first phase of the migration the new models are populated
whenever ``SqlaTable`` is modified (created, updated, or deleted).

In the second phase of the migration reads will be done from the new models.
Finally, in the third phase of the migration the old models will be removed.

For more context: https://github.com/apache/superset/issues/14909
"""

engine = database.get_sqla_engine(schema=dataset.schema)
conditional_quote = engine.dialect.identifier_preparer.quote

# create columns
columns = []
for column in dataset.columns:
# ``is_active`` might be ``None`` at this point, but it defaults to ``True``.
if column.is_active is False:
continue

extra_json = json.loads(column.extra or "{}")
for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}:
value = getattr(column, attr)
if value:
extra_json[attr] = value

columns.append(
NewColumn(
name=column.column_name,
type=column.type or "Unknown",
expression=column.expression
or conditional_quote(column.column_name),
description=column.description,
is_temporal=column.is_dttm,
is_aggregation=False,
is_physical=column.expression is None,
is_spatial=False,
is_partition=False,
is_increase_desired=True,
extra_json=json.dumps(extra_json) if extra_json else None,
is_managed_externally=dataset.is_managed_externally,
external_url=dataset.external_url,
),
)

# create metrics
for metric in dataset.metrics:
extra_json = json.loads(metric.extra or "{}")
for attr in {"verbose_name", "metric_type", "d3format"}:
value = getattr(metric, attr)
if value:
extra_json[attr] = value

is_additive = (
metric.metric_type
and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES
)

columns.append(
NewColumn(
name=metric.metric_name,
type="Unknown", # figuring this out would require a type inferrer
expression=metric.expression,
warning_text=metric.warning_text,
description=metric.description,
is_aggregation=True,
is_additive=is_additive,
is_physical=False,
is_spatial=False,
is_partition=False,
is_increase_desired=True,
extra_json=json.dumps(extra_json) if extra_json else None,
is_managed_externally=dataset.is_managed_externally,
external_url=dataset.external_url,
),
)

# physical dataset
tables = []
if dataset.sql is None:
physical_columns = [column for column in columns if column.is_physical]

# create table
table = NewTable(
name=dataset.table_name,
schema=dataset.schema,
catalog=None, # currently not supported
database_id=dataset.database_id,
columns=physical_columns,
is_managed_externally=dataset.is_managed_externally,
external_url=dataset.external_url,
)
tables.append(table)

# virtual dataset
else:
# mark all columns as virtual (not physical)
for column in columns:
column.is_physical = False

# find referenced tables
parsed = ParsedQuery(dataset.sql)
referenced_tables = parsed.tables

# predicate for finding the referenced tables
predicate = or_(
*[
and_(
NewTable.schema == (table.schema or dataset.schema),
NewTable.name == table.table,
)
for table in referenced_tables
]
)
tables = session.query(NewTable).filter(predicate).all()

# create the new dataset
new_dataset = NewDataset(
sqlatable_id=dataset.id,
name=dataset.table_name,
expression=dataset.sql or conditional_quote(dataset.table_name),
tables=tables,
columns=columns,
is_physical=dataset.sql is None,
is_managed_externally=dataset.is_managed_externally,
external_url=dataset.external_url,
)
session.add(new_dataset)


sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update)
sa.event.listen(SqlaTable, "after_insert", SqlaTable.after_insert)
Expand Down
Loading