fix: add new config to allow for specific import data urls #22942

Merged
7 changes: 7 additions & 0 deletions superset/config.py
@@ -1406,6 +1406,13 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument
# Prevents unsafe default endpoints to be registered on datasets.
PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET = True

# Define a list of allowed URLs for dataset data imports (v1).
# Simple example to only allow URLs that belong to certain domains:
# DATASET_IMPORT_ALLOWED_DATA_URLS = [
# r"^https://.+\.domain1\.com\/?.*", r"^https://.+\.domain2\.com\/?.*"
# ]
DATASET_IMPORT_ALLOWED_DATA_URLS = [r".*"]

# Path used to store SSL certificates that are generated when using custom certs.
# Defaults to temporary directory.
# Example: SSL_CERT_PATH = "/certs"
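Usage note (not part of the diff): a deployment that wants to restrict data imports to specific hosts would override the new setting in its superset_config.py. A minimal sketch follows; the domain name is a placeholder chosen for illustration, not taken from this PR.

# superset_config.py -- hypothetical deployment override
# Accept dataset data URLs only from HTTPS hosts under example-data.com;
# any other URL makes the import fail with DatasetForbiddenDataURI.
DATASET_IMPORT_ALLOWED_DATA_URLS = [
    r"^https://([^/]+\.)?example-data\.com/.*",
]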
4 changes: 4 additions & 0 deletions superset/datasets/commands/exceptions.py
@@ -191,3 +191,7 @@ class DatasetAccessDeniedError(ForbiddenError):

class DatasetDuplicateFailedError(CreateFailedError):
message = _("Dataset could not be duplicated.")


class DatasetForbiddenDataURI(ForbiddenError):
message = _("Data URI is not allowed.")
32 changes: 31 additions & 1 deletion superset/datasets/commands/importers/v1/utils.py
@@ -29,6 +29,7 @@
from sqlalchemy.sql.visitors import VisitableType

from superset.connectors.sqla.models import SqlaTable
from superset.datasets.commands.exceptions import DatasetForbiddenDataURI
from superset.models.core import Database

logger = logging.getLogger(__name__)
@@ -75,6 +76,28 @@ def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> Dict[str, VisitableType]:
}


def validate_data_uri(data_uri: str) -> None:
"""
Validate that the data URI matches one of the regular expressions
configured in DATASET_IMPORT_ALLOWED_DATA_URLS.

:param data_uri: URL the dataset data would be downloaded from
:raises DatasetForbiddenDataURI: if the URI matches none of the allowed patterns
"""
allowed_urls = current_app.config["DATASET_IMPORT_ALLOWED_DATA_URLS"]
for allowed_url in allowed_urls:
try:
match = re.match(allowed_url, data_uri)
except re.error:
logger.exception(
"Invalid regular expression on DATASET_IMPORT_ALLOWED_URLS"
)
raise
if match:
return
raise DatasetForbiddenDataURI()


def import_dataset(
session: Session,
config: Dict[str, Any],
@@ -139,7 +162,6 @@ def import_dataset(
table_exists = True

if data_uri and (not table_exists or force_data):
logger.info("Downloading data from %s", data_uri)
load_data(data_uri, dataset, dataset.database, session)

if hasattr(g, "user") and g.user:
@@ -151,6 +173,14 @@ def load_data(
def load_data(
data_uri: str, dataset: SqlaTable, database: Database, session: Session
) -> None:
"""
Load data from a data URI into a dataset.

:raises DatasetForbiddenDataURI: if the dataset is trying
to load data from a URI that is not allowed.
"""
validate_data_uri(data_uri)
logger.info("Downloading data from %s", data_uri)
data = request.urlopen(data_uri) # pylint: disable=consider-using-with
if data_uri.endswith(".gz"):
data = gzip.open(data)
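To make the allow-list semantics concrete, here is a minimal self-contained sketch of the same check outside Superset; the pattern and function name are illustrative assumptions, not code from this PR.

import re

# Hypothetical allow-list mirroring DATASET_IMPORT_ALLOWED_DATA_URLS.
ALLOWED_DATA_URL_PATTERNS = [r"^https://.+\.domain1\.com/?.*"]

def is_allowed(data_uri: str) -> bool:
    # A URI is accepted as soon as any configured pattern matches it;
    # an empty pattern list therefore rejects every URI, as the new tests assert.
    return any(re.match(pattern, data_uri) for pattern in ALLOWED_DATA_URL_PATTERNS)

print(is_allowed("https://host1.domain1.com/data.csv"))  # True
print(is_allowed("https://host1.domain3.com/data.csv"))  # False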
128 changes: 123 additions & 5 deletions tests/unit_tests/datasets/commands/importers/v1/import_test.py
@@ -18,19 +18,25 @@

import copy
import json
import re
import uuid
from typing import Any, Dict
from unittest.mock import Mock, patch

import pytest
from flask import current_app
from sqlalchemy.orm.session import Session

from superset.datasets.commands.exceptions import DatasetForbiddenDataURI
from superset.datasets.commands.importers.v1.utils import validate_data_uri


def test_import_dataset(session: Session) -> None:
"""
Test importing a dataset.
"""
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.connectors.sqla.models import SqlaTable
from superset.datasets.commands.importers.v1.utils import import_dataset
from superset.datasets.schemas import ImportV1DatasetSchema
from superset.models.core import Database

engine = session.get_bind()
@@ -340,13 +346,85 @@ def test_import_column_extra_is_string(session: Session) -> None:
assert sqla_table.extra == '{"warning_markdown": "*WARNING*"}'


@patch("superset.datasets.commands.importers.v1.utils.request")
def test_import_column_allowed_data_url(request: Mock, session: Session) -> None:
"""
Test importing a dataset when using data key to fetch data from a URL.
"""
import io

from superset.connectors.sqla.models import SqlaTable
from superset.datasets.commands.importers.v1.utils import import_dataset
from superset.datasets.schemas import ImportV1DatasetSchema
from superset.models.core import Database

request.urlopen.return_value = io.StringIO("col1\nvalue1\nvalue2\n")

engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member

database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
session.add(database)
session.flush()

dataset_uuid = uuid.uuid4()
yaml_config: Dict[str, Any] = {
"version": "1.0.0",
"table_name": "my_table",
"main_dttm_col": "ds",
"description": "This is the description",
"default_endpoint": None,
"offset": -8,
"cache_timeout": 3600,
"schema": None,
"sql": None,
"params": {
"remote_id": 64,
"database_name": "examples",
"import_time": 1606677834,
},
"template_params": None,
"filter_select_enabled": True,
"fetch_values_predicate": None,
"extra": None,
"uuid": dataset_uuid,
"metrics": [],
"columns": [
{
"column_name": "col1",
"verbose_name": None,
"is_dttm": False,
"is_active": True,
"type": "TEXT",
"groupby": False,
"filterable": False,
"expression": None,
"description": None,
"python_date_format": None,
"extra": None,
}
],
"database_uuid": database.uuid,
"data": "https://some-external-url.com/data.csv",
}

# the Marshmallow schema should convert strings to objects
schema = ImportV1DatasetSchema()
dataset_config = schema.load(yaml_config)
dataset_config["database_id"] = database.id
_ = import_dataset(session, dataset_config, force_data=True)
session.connection()
assert [("value1",), ("value2",)] == session.execute(
"SELECT * FROM my_table"
).fetchall()


def test_import_dataset_managed_externally(session: Session) -> None:
"""
Test importing a dataset that is managed externally.
"""
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.connectors.sqla.models import SqlaTable
from superset.datasets.commands.importers.v1.utils import import_dataset
from superset.datasets.schemas import ImportV1DatasetSchema
from superset.models.core import Database
from tests.integration_tests.fixtures.importexport import dataset_config

@@ -357,7 +435,6 @@ def test_import_dataset_managed_externally(session: Session) -> None:
session.add(database)
session.flush()

dataset_uuid = uuid.uuid4()
config = copy.deepcopy(dataset_config)
config["is_managed_externally"] = True
config["external_url"] = "https://example.org/my_table"
@@ -366,3 +443,44 @@
sqla_table = import_dataset(session, config)
assert sqla_table.is_managed_externally is True
assert sqla_table.external_url == "https://example.org/my_table"


@pytest.mark.parametrize(
"allowed_urls, data_uri, expected, exception_class",
[
([r".*"], "https://some-url/data.csv", True, None),
(
[r"^https://.+\.domain1\.com\/?.*", r"^https://.+\.domain2\.com\/?.*"],
"https://host1.domain1.com/data.csv",
True,
None,
),
(
[r"^https://.+\.domain1\.com\/?.*", r"^https://.+\.domain2\.com\/?.*"],
"https://host2.domain1.com/data.csv",
True,
None,
),
(
[r"^https://.+\.domain1\.com\/?.*", r"^https://.+\.domain2\.com\/?.*"],
"https://host1.domain2.com/data.csv",
True,
None,
),
(
[r"^https://.+\.domain1\.com\/?.*", r"^https://.+\.domain2\.com\/?.*"],
"https://host1.domain3.com/data.csv",
False,
DatasetForbiddenDataURI,
),
([], "https://host1.domain3.com/data.csv", False, DatasetForbiddenDataURI),
(["*"], "https://host1.domain3.com/data.csv", False, re.error),
],
)
def test_validate_data_uri(allowed_urls, data_uri, expected, exception_class):
current_app.config["DATASET_IMPORT_ALLOWED_DATA_URLS"] = allowed_urls
if expected:
validate_data_uri(data_uri)
else:
with pytest.raises(exception_class):
validate_data_uri(data_uri)