feat: add UUID column to ImportMixin (#11098)
* Add UUID column to ImportMixin

* Fix default value

* Fix lint

* Fix order of downgrade

* Add logging when downgrade fails

* Migrate position_json to contain UUIDs, and add schedule tables

* Save UUID when adding charts to dashboard

* Fix heads

* Rename migration file

* Fix dashboard serialization

* Fix migration script with Postgres

* Fix unique constraint name

* Handle UUID when exporting dashboard

* Fix Dataset PUT

* Add UUID JSON serialization

* Fix tests

* Simplify logic

* Try binary=True
betodealmeida authored Oct 7, 2020
1 parent 6e0d1b8 commit 9785667
Showing 10 changed files with 198 additions and 12 deletions.
11 changes: 11 additions & 0 deletions superset/dashboards/dao.py
@@ -99,6 +99,17 @@ def set_dash_metadata(

        dashboard.slices = current_slices

        # add UUID to positions
        uuid_map = {slice.id: str(slice.uuid) for slice in current_slices}
        for obj in positions.values():
            if (
                isinstance(obj, dict)
                and obj["type"] == "CHART"
                and obj["meta"]["chartId"]
            ):
                chart_id = obj["meta"]["chartId"]
                obj["meta"]["uuid"] = uuid_map[chart_id]

        # remove leading and trailing white spaces in the dumped json
        dashboard.position_json = json.dumps(
            positions, indent=None, separators=(",", ":"), sort_keys=True
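
As an aside, a minimal standalone sketch of what the new block does (hypothetical layout data, not part of this commit): every CHART node in the dashboard layout is stamped with the UUID of its backing slice.

import json
from uuid import uuid4

# Hypothetical stand-ins for two persisted slices, keyed by chart id.
uuid_map = {52: str(uuid4()), 53: str(uuid4())}
positions = {
    "DASHBOARD_VERSION_KEY": "v2",
    "CHART-abc123": {"type": "CHART", "meta": {"chartId": 52}},
    "CHART-def456": {"type": "CHART", "meta": {"chartId": 53}},
}

for obj in positions.values():
    if isinstance(obj, dict) and obj["type"] == "CHART" and obj["meta"]["chartId"]:
        obj["meta"]["uuid"] = uuid_map[obj["meta"]["chartId"]]

# Mirrors the compact dump above: no indentation, no spaces after separators.
print(json.dumps(positions, indent=None, separators=(",", ":"), sort_keys=True))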
1 change: 1 addition & 0 deletions superset/datasets/schemas.py
@@ -52,6 +52,7 @@ class DatasetColumnsPutSchema(Schema):
    python_date_format = fields.String(
        allow_none=True, validate=[Length(1, 255), validate_python_date_format]
    )
    uuid = fields.String(allow_none=True)


class DatasetMetricsPutSchema(Schema):
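
A hedged sketch of the effect (a hypothetical mirror schema, not Superset's full DatasetColumnsPutSchema): because the field allows None and is optional, payloads that omit uuid keep validating.

from uuid import uuid4

from marshmallow import Schema, fields

# Hypothetical minimal mirror of the change: uuid is optional and nullable.
class ColumnPutSchema(Schema):
    column_name = fields.String(required=True)
    uuid = fields.String(allow_none=True)

schema = ColumnPutSchema()
print(schema.load({"column_name": "ds", "uuid": str(uuid4())}))
print(schema.load({"column_name": "ds"}))  # still valid without a uuid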
154 changes: 154 additions & 0 deletions superset/migrations/versions/b56500de1855_add_uuid_column_to_import_mixin.py
@@ -0,0 +1,154 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""add_uuid_column_to_import_mixin
Revision ID: b56500de1855
Revises: 18532d70ab98
Create Date: 2020-09-28 17:57:23.128142
"""
import json
import logging
import uuid

import sqlalchemy as sa
from alembic import op
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import UUIDType

from superset import db
from superset.utils import core as utils

# revision identifiers, used by Alembic.
revision = "b56500de1855"
down_revision = "18532d70ab98"


Base = declarative_base()


class ImportMixin:
    id = sa.Column(sa.Integer, primary_key=True)
    uuid = sa.Column(UUIDType(binary=True), primary_key=False, default=uuid.uuid4)


table_names = [
    # Core models
    "dbs",
    "dashboards",
    "slices",
    # SQLAlchemy connectors
    "tables",
    "table_columns",
    "sql_metrics",
    # Druid connector
    "clusters",
    "datasources",
    "columns",
    "metrics",
    # Dashboard email schedules
    "dashboard_email_schedules",
    "slice_email_schedules",
]
models = {
    table_name: type(table_name, (Base, ImportMixin), {"__tablename__": table_name})
    for table_name in table_names
}

models["dashboards"].position_json = sa.Column(utils.MediumText())


def add_uuids(objects, session, batch_size=100):
    uuid_map = {}
    count = len(objects)
    for i, object_ in enumerate(objects):
        object_.uuid = uuid.uuid4()
        uuid_map[object_.id] = object_.uuid
        session.merge(object_)
        if (i + 1) % batch_size == 0:
            session.commit()
            print(f"uuid assigned to {i + 1} out of {count}")

    session.commit()
    print(f"Done! Assigned {count} uuids")

    return uuid_map


def update_position_json(dashboard, session, uuid_map):
    layout = json.loads(dashboard.position_json or "{}")
    for object_ in layout.values():
        if (
            isinstance(object_, dict)
            and object_["type"] == "CHART"
            and object_["meta"]["chartId"]
        ):
            chart_id = object_["meta"]["chartId"]
            if chart_id in uuid_map:
                object_["meta"]["uuid"] = str(uuid_map[chart_id])
            elif object_["meta"].get("uuid"):
                del object_["meta"]["uuid"]

    dashboard.position_json = json.dumps(layout, indent=4)
    session.merge(dashboard)
    session.commit()


def upgrade():
    bind = op.get_bind()
    session = db.Session(bind=bind)

    uuid_maps = {}
    for table_name, model in models.items():
        with op.batch_alter_table(table_name) as batch_op:
            batch_op.add_column(
                sa.Column(
                    "uuid",
                    UUIDType(binary=True),
                    primary_key=False,
                    default=uuid.uuid4,
                )
            )

        # populate column
        objects = session.query(model).all()
        uuid_maps[table_name] = add_uuids(objects, session)

        # add uniqueness constraint
        with op.batch_alter_table(table_name) as batch_op:
            batch_op.create_unique_constraint(f"uq_{table_name}_uuid", ["uuid"])

    # add UUID to Dashboard.position_json
    Dashboard = models["dashboards"]
    for dashboard in session.query(Dashboard).all():
        update_position_json(dashboard, session, uuid_maps["slices"])


def downgrade():
    bind = op.get_bind()
    session = db.Session(bind=bind)

    # remove uuid from position_json
    Dashboard = models["dashboards"]
    for dashboard in session.query(Dashboard).all():
        update_position_json(dashboard, session, {})

    # remove uuid column
    for table_name, model in models.items():
        with op.batch_alter_table(table_name) as batch_op:
            batch_op.drop_constraint(f"uq_{table_name}_uuid")
            batch_op.drop_column("uuid")
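
For readers unfamiliar with batch mode, here is a self-contained sketch of the pattern upgrade() relies on (throwaway in-memory SQLite and a hypothetical table, not part of the migration): batch_alter_table recreates the table under the hood, which is what lets SQLite gain both the new column and, in a second pass once rows are backfilled, the unique constraint. Assumes SQLAlchemy 1.4+, Alembic, and sqlalchemy-utils are installed.

import sqlalchemy as sa
from alembic.migration import MigrationContext
from alembic.operations import Operations
from sqlalchemy_utils import UUIDType

engine = sa.create_engine("sqlite://")
with engine.connect() as conn:
    conn.execute(sa.text("CREATE TABLE slices (id INTEGER PRIMARY KEY)"))
    op = Operations(MigrationContext.configure(conn))

    # First batch: add the nullable uuid column (backfill would happen here).
    with op.batch_alter_table("slices") as batch_op:
        batch_op.add_column(sa.Column("uuid", UUIDType(binary=True)))

    # Second batch: enforce uniqueness only after rows are populated.
    with op.batch_alter_table("slices") as batch_op:
        batch_op.create_unique_constraint("uq_slices_uuid", ["uuid"])

    print([col["name"] for col in sa.inspect(conn).get_columns("slices")])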
6 changes: 6 additions & 0 deletions superset/models/helpers.py
@@ -18,6 +18,7 @@
import json
import logging
import re
import uuid
from datetime import datetime, timedelta
from json.decoder import JSONDecodeError
from typing import Any, Dict, List, Optional, Set, Union
@@ -35,6 +36,7 @@
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import Mapper, Session
from sqlalchemy.orm.exc import MultipleResultsFound
from sqlalchemy_utils import UUIDType

from superset.utils.core import QueryStatus

@@ -51,6 +53,10 @@ def json_to_dict(json_str: str) -> Dict[Any, Any]:


class ImportMixin:
    uuid = sa.Column(
        UUIDType(binary=True), primary_key=False, unique=True, default=uuid.uuid4
    )

    export_parent: Optional[str] = None
    # The name of the attribute
    # with the SQL Alchemy back reference
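
A brief sketch of why a plain class attribute is enough here (hypothetical model, not Superset code; assumes SQLAlchemy 1.4+): SQLAlchemy copies Column objects declared on a mixin onto every mapped subclass, so each ImportMixin model gets its own unique, defaulted uuid column.

import uuid

import sqlalchemy as sa
from sqlalchemy.orm import declarative_base
from sqlalchemy_utils import UUIDType

Base = declarative_base()

class ImportMixin:
    uuid = sa.Column(
        UUIDType(binary=True), primary_key=False, unique=True, default=uuid.uuid4
    )

class Widget(Base, ImportMixin):  # hypothetical model
    __tablename__ = "widgets"
    id = sa.Column(sa.Integer, primary_key=True)

print(Widget.__table__.c.uuid.unique)  # True: the column was copied in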
6 changes: 3 additions & 3 deletions superset/models/slice.py
@@ -198,15 +198,15 @@ def data(self) -> Dict[str, Any]:
    @property
    def digest(self) -> str:
        """
-        Returns a MD5 HEX digest that makes this dashboard unique
+        Returns a MD5 HEX digest that makes this dashboard unique
        """
        return utils.md5_hex(self.params)

    @property
    def thumbnail_url(self) -> str:
        """
-        Returns a thumbnail URL with a HEX digest. We want to avoid browser cache
-        if the dashboard has changed
+        Returns a thumbnail URL with a HEX digest. We want to avoid browser cache
+        if the dashboard has changed
        """
        return f"/api/v1/chart/{self.id}/thumbnail/{self.digest}/"
4 changes: 3 additions & 1 deletion superset/utils/core.py
@@ -310,7 +310,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.sort_keys = True

-    def default(self, o: Any) -> Dict[Any, Any]:
+    def default(self, o: Any) -> Union[Dict[Any, Any], str]:
+        if isinstance(o, uuid.UUID):
+            return str(o)
        try:
            vals = {k: v for k, v in o.__dict__.items() if k != "_sa_instance_state"}
            return {"__{}__".format(o.__class__.__name__): vals}
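
The same behavior in isolation (a minimal hypothetical encoder, not the full Superset one): without the isinstance branch, json.dumps raises TypeError on uuid.UUID values.

import json
import uuid

class UUIDEncoder(json.JSONEncoder):  # hypothetical minimal mirror
    def default(self, o):
        if isinstance(o, uuid.UUID):
            return str(o)
        return super().default(o)

print(json.dumps({"uuid": uuid.uuid4()}, cls=UUIDEncoder))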
2 changes: 1 addition & 1 deletion superset/views/core.py
@@ -1784,7 +1784,7 @@ def sync_druid_source(self) -> FlaskResponse:  # pylint: disable=no-self-use
    @expose("/get_or_create_table/", methods=["POST"])
    @event_logger.log_this
    def sqllab_table_viz(self) -> FlaskResponse:  # pylint: disable=no-self-use
-        """ Gets or creates a table object with attributes passed to the API.
+        """Gets or creates a table object with attributes passed to the API.
        It expects the json with params:
        * datasourceName - e.g. table name, required
2 changes: 1 addition & 1 deletion tests/databases/api_tests.py
@@ -328,8 +328,8 @@ def test_create_database_fail_sqllite(self):
                ]
            }
        }
-        self.assertEqual(response.status_code, 400)
        self.assertEqual(response_data, expected_response)
+        self.assertEqual(response.status_code, 400)

    def test_create_database_conn_fail(self):
        """
20 changes: 16 additions & 4 deletions tests/dict_import_export_tests.py
@@ -18,6 +18,7 @@
"""Unit tests for Superset"""
import json
import unittest
+from uuid import uuid4

import yaml

@@ -64,26 +65,33 @@ def setUpClass(cls):
    def tearDownClass(cls):
        cls.delete_imports()

-    def create_table(self, name, schema="", id=0, cols_names=[], metric_names=[]):
+    def create_table(
+        self, name, schema="", id=0, cols_names=[], cols_uuids=None, metric_names=[]
+    ):
        database_name = "main"
        name = "{0}{1}".format(NAME_PREFIX, name)
        params = {DBREF: id, "database_name": database_name}

+        if cols_uuids is None:
+            cols_uuids = [None] * len(cols_names)
+
        dict_rep = {
            "database_id": get_example_database().id,
            "table_name": name,
            "schema": schema,
            "id": id,
            "params": json.dumps(params),
-            "columns": [{"column_name": c} for c in cols_names],
+            "columns": [
+                {"column_name": c, "uuid": u} for c, u in zip(cols_names, cols_uuids)
+            ],
            "metrics": [{"metric_name": c, "expression": ""} for c in metric_names],
        }

        table = SqlaTable(
            id=id, schema=schema, table_name=name, params=json.dumps(params)
        )
-        for col_name in cols_names:
-            table.columns.append(TableColumn(column_name=col_name))
+        for col_name, uuid in zip(cols_names, cols_uuids):
+            table.columns.append(TableColumn(column_name=col_name, uuid=uuid))
        for metric_name in metric_names:
            table.metrics.append(SqlMetric(metric_name=metric_name, expression=""))
        return table, dict_rep
@@ -171,6 +179,7 @@ def test_import_table_1_col_1_met(self):
            "table_1_col_1_met",
            id=ID_PREFIX + 2,
            cols_names=["col1"],
+            cols_uuids=[uuid4()],
            metric_names=["metric1"],
        )
        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
@@ -187,6 +196,7 @@ def test_import_table_2_col_2_met(self):
            "table_2_col_2_met",
            id=ID_PREFIX + 3,
            cols_names=["c1", "c2"],
+            cols_uuids=[uuid4(), uuid4()],
            metric_names=["m1", "m2"],
        )
        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
@@ -217,6 +227,7 @@ def test_import_table_override_append(self):
            id=ID_PREFIX + 3,
            metric_names=["new_metric1", "m1"],
            cols_names=["col1", "new_col1", "col2", "col3"],
+            cols_uuids=[col.uuid for col in imported_over.columns],
        )
        self.assert_table_equals(expected_table, imported_over)
        self.yaml_compare(
@@ -247,6 +258,7 @@ def test_import_table_override_sync(self):
            id=ID_PREFIX + 3,
            metric_names=["new_metric1"],
            cols_names=["new_col1", "col2", "col3"],
+            cols_uuids=[col.uuid for col in imported_over.columns],
        )
        self.assert_table_equals(expected_table, imported_over)
        self.yaml_compare(
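
The zip-based pairing in isolation (a hypothetical helper, not the test code itself): defaulting cols_uuids to a list of None values keeps old call sites working while letting new ones pin UUIDs per column.

from uuid import uuid4

def pair_columns(cols_names, cols_uuids=None):  # hypothetical mirror
    if cols_uuids is None:
        cols_uuids = [None] * len(cols_names)
    return [{"column_name": c, "uuid": u} for c, u in zip(cols_names, cols_uuids)]

print(pair_columns(["c1", "c2"]))       # uuids default to None
print(pair_columns(["c1"], [uuid4()]))  # explicit uuid preserved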
4 changes: 2 additions & 2 deletions tests/import_export_tests.py
@@ -225,8 +225,8 @@ def assert_slice_equals(self, expected_slc, actual_slc):
        self.assertEqual(exp_params, actual_params)

    def assert_only_exported_slc_fields(self, expected_dash, actual_dash):
-        """ only exported json has this params
-        imported/created dashboard has relationships to other models instead
+        """only exported json has this params
+        imported/created dashboard has relationships to other models instead
        """
        expected_slices = sorted(expected_dash.slices, key=lambda s: s.slice_name or "")
        actual_slices = sorted(actual_dash.slices, key=lambda s: s.slice_name or "")
