Skip to content

Commit

Permalink
fix: Remedy logic for UpdateDatasetCommand uniqueness check
Browse files Browse the repository at this point in the history
  • Loading branch information
john-bodley committed May 4, 2024
1 parent 2e9cc65 commit 999fb95
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 26 deletions.
7 changes: 3 additions & 4 deletions superset/commands/dataset/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,16 @@ def run(self) -> Model:
def validate(self) -> None:
exceptions: list[ValidationError] = []
database_id = self._properties["database"]
table_name = self._properties["table_name"]
schema = self._properties.get("schema")
catalog = self._properties.get("catalog")
sql = self._properties.get("sql")
owner_ids: Optional[list[int]] = self._properties.get("owners")

table = Table(table_name, schema, catalog)
table = Table(self._properties["table_name"], schema, catalog)

# Validate uniqueness
if not DatasetDAO.validate_uniqueness(database_id, table):
exceptions.append(DatasetExistsValidationError(table_name))
exceptions.append(DatasetExistsValidationError(table))

# Validate/Populate database
database = DatasetDAO.get_database_by_id(database_id)
Expand All @@ -86,7 +85,7 @@ def validate(self) -> None:
and not sql
and not DatasetDAO.validate_table_exists(database, table)
):
exceptions.append(TableNotFoundValidationError(table_name))
exceptions.append(TableNotFoundValidationError(table))

if sql:
try:
Expand Down
4 changes: 2 additions & 2 deletions superset/commands/dataset/duplicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from superset.exceptions import SupersetErrorException
from superset.extensions import db
from superset.models.core import Database
from superset.sql_parse import ParsedQuery
from superset.sql_parse import ParsedQuery, Table

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -124,7 +124,7 @@ def validate(self) -> None:
exceptions.append(DatasourceTypeInvalidError())

if DatasetDAO.find_one_or_none(table_name=duplicate_name):
exceptions.append(DatasetExistsValidationError(table_name=duplicate_name))
exceptions.append(DatasetExistsValidationError(table=Table(duplicate_name)))

try:
owners = self.populate_owners()
Expand Down
19 changes: 9 additions & 10 deletions superset/commands/dataset/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
ImportFailedError,
UpdateFailedError,
)
from superset.sql_parse import Table


def get_dataset_exist_error_msg(full_name: str) -> str:
return _("Dataset %(name)s already exists", name=full_name)
def get_dataset_exist_error_msg(table: Table) -> str:
return _("Dataset %(table)s already exists", table=table)


class DatabaseNotFoundValidationError(ValidationError):
Expand All @@ -55,10 +56,8 @@ class DatasetExistsValidationError(ValidationError):
Marshmallow validation error for dataset already exists
"""

def __init__(self, table_name: str) -> None:
super().__init__(
[get_dataset_exist_error_msg(table_name)], field_name="table_name"
)
def __init__(self, table: Table) -> None:
super().__init__([get_dataset_exist_error_msg(table)], field_name="table")


class DatasetColumnNotFoundValidationError(ValidationError):
Expand Down Expand Up @@ -124,18 +123,18 @@ class TableNotFoundValidationError(ValidationError):
Marshmallow validation error when a table does not exist on the database
"""

def __init__(self, table_name: str) -> None:
def __init__(self, table: Table) -> None:
super().__init__(
[
_(
"Table [%(table_name)s] could not be found, "
"Table [%(table)s] could not be found, "
"please double check your "
"database connection, schema, and "
"table name",
table_name=table_name,
table=table,
)
],
field_name="table_name",
field_name="table",
)


Expand Down
14 changes: 10 additions & 4 deletions superset/commands/dataset/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,21 @@ def validate(self) -> None:
except SupersetSecurityException as ex:
raise DatasetForbiddenError() from ex

database_id = self._properties.get("database", None)
table_name = self._properties.get("table_name", None)
database_id = self._properties.get("database")

table = Table(
self._properties.get("table_name"), # type: ignore
self._properties.get("schema"),
self._model.catalog,
)

# Validate uniqueness
if not DatasetDAO.validate_update_uniqueness(
self._model.database_id,
Table(table_name, self._model.schema, self._model.catalog),
table,
self._model_id,
):
exceptions.append(DatasetExistsValidationError(table_name))
exceptions.append(DatasetExistsValidationError(table))
# Validate/Populate database not allowed to change
if database_id and database_id != self._model:
exceptions.append(DatabaseChangeValidationError())
Expand Down
6 changes: 3 additions & 3 deletions tests/integration_tests/datasets/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ def test_create_dataset_validate_uniqueness(self):
assert rv.status_code == 422
data = json.loads(rv.data.decode("utf-8"))
assert data == {
"message": {"table_name": ["Dataset energy_usage already exists"]}
"message": {"table": ["Dataset main.energy_usage already exists"]}
}

@pytest.mark.usefixtures("load_energy_table_with_slice")
Expand All @@ -682,7 +682,7 @@ def test_create_dataset_with_sql_validate_uniqueness(self):
assert rv.status_code == 422
data = json.loads(rv.data.decode("utf-8"))
assert data == {
"message": {"table_name": ["Dataset energy_usage already exists"]}
"message": {"table": ["Dataset main.energy_usage already exists"]}
}

@pytest.mark.usefixtures("load_energy_table_with_slice")
Expand Down Expand Up @@ -1429,7 +1429,7 @@ def test_update_dataset_item_uniqueness(self):
data = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
expected_response = {
"message": {"table_name": ["Dataset ab_user already exists"]}
"message": {"table": ["Dataset ab_user already exists"]}
}
assert data == expected_response
db.session.delete(dataset)
Expand Down
60 changes: 60 additions & 0 deletions tests/unit_tests/commands/dataset/test_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest.mock import MagicMock

import pytest
from pytest_mock import MockFixture

from superset import db
from superset.commands.dataset.exceptions import DatasetInvalidError
from superset.commands.dataset.update import UpdateDatasetCommand
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database


@pytest.mark.usefixtures("session")  # fixed typo: "usefixture" is an unknown mark pytest silently ignores
def test_update_uniqueness_error(mocker: MockFixture) -> None:
    """UpdateDatasetCommand must reject an update whose new (schema, table_name)
    collides with an existing dataset on the same database.

    Two datasets exist: ``foo.bar`` and ``qux.baz``. Renaming ``foo.bar`` to
    ``qux.baz`` duplicates the second dataset, so ``run()`` should raise
    ``DatasetInvalidError`` (wrapping the uniqueness validation error).
    """
    SqlaTable.metadata.create_all(db.session.get_bind())
    database = Database(database_name="my_db", sqlalchemy_uri="sqlite://")
    bar = SqlaTable(table_name="bar", schema="foo", database=database)
    baz = SqlaTable(table_name="baz", schema="qux", database=database)
    db.session.add_all([database, bar, baz])
    db.session.commit()

    # Stub out auth so validate() reaches the uniqueness check rather than
    # failing on permissions/ownership first.
    mock_g = mocker.patch("superset.security.manager.g")
    mock_g.user = MagicMock()

    mocker.patch(
        "superset.views.base.security_manager.can_access_all_datasources",
        return_value=True,
    )

    mocker.patch(
        "superset.commands.dataset.update.security_manager.raise_for_ownership",
        return_value=None,
    )

    mocker.patch.object(UpdateDatasetCommand, "compute_owners", return_value=[])

    # Attempt to rename "foo.bar" to the already-taken "qux.baz".
    with pytest.raises(DatasetInvalidError):
        UpdateDatasetCommand(
            bar.id,
            {
                "table_name": "baz",
                "schema": "qux",
            },
        ).run()
4 changes: 1 addition & 3 deletions tests/unit_tests/dao/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_validate_update_uniqueness(session: Session) -> None:
db.session.add_all([database, dataset1, dataset2])
db.session.flush()

# same table name, different schema
#
assert (
DatasetDAO.validate_update_uniqueness(
database_id=database.id,
Expand All @@ -61,7 +61,6 @@ def test_validate_update_uniqueness(session: Session) -> None:
is True
)

# duplicate schema and table name
assert (
DatasetDAO.validate_update_uniqueness(
database_id=database.id,
Expand All @@ -71,7 +70,6 @@ def test_validate_update_uniqueness(session: Session) -> None:
is False
)

# no schema
assert (
DatasetDAO.validate_update_uniqueness(
database_id=database.id,
Expand Down

0 comments on commit 999fb95

Please sign in to comment.