Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Security manager incorrect calls #29884

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 53 additions & 2 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
import dataclasses
import logging
import re
from collections import defaultdict
from collections.abc import Hashable
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any, Callable, cast
from typing import Any, Callable, cast, Optional, Union

import dateutil.parser
import numpy as np
Expand Down Expand Up @@ -69,7 +70,7 @@
from sqlalchemy.sql.expression import Label, TextAsFrom
from sqlalchemy.sql.selectable import Alias, TableClause

from superset import app, db, security_manager
from superset import app, db, is_feature_enabled, security_manager
from superset.commands.dataset.exceptions import DatasetNotFoundError
from superset.common.db_query_status import QueryStatus
from superset.connectors.sqla.utils import (
Expand Down Expand Up @@ -710,6 +711,56 @@
) -> BaseDatasource | None:
raise NotImplementedError()

def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:

Check warning on line 714 in superset/connectors/sqla/models.py

View check run for this annotation

Codecov / codecov/patch

superset/connectors/sqla/models.py#L714

Added line #L714 was not covered by tests
raise NotImplementedError()

def text(self, clause: str) -> TextClause:

Check warning on line 717 in superset/connectors/sqla/models.py

View check run for this annotation

Codecov / codecov/patch

superset/connectors/sqla/models.py#L717

Added line #L717 was not covered by tests
raise NotImplementedError()

def get_sqla_row_level_filters(

Check warning on line 720 in superset/connectors/sqla/models.py

View check run for this annotation

Codecov / codecov/patch

superset/connectors/sqla/models.py#L720

Added line #L720 was not covered by tests
self,
template_processor: Optional[BaseTemplateProcessor] = None,
) -> list[TextClause]:
"""
Return the appropriate row level security filters for this table and the
current user. A custom username can be passed when the user is not present in the
Flask global namespace.

:param template_processor: The template processor to apply to the filters.
:returns: A list of SQL clauses to be ANDed together.
"""
template_processor = template_processor or self.get_template_processor()

all_filters: list[TextClause] = []
filter_groups: dict[Union[int, str], list[TextClause]] = defaultdict(list)
try:
for filter_ in security_manager.get_rls_filters(self):
clause = self.text(
f"({template_processor.process_template(filter_.clause)})"
)
if filter_.group_key:
filter_groups[filter_.group_key].append(clause)
else:
all_filters.append(clause)

if is_feature_enabled("EMBEDDED_SUPERSET"):
for rule in security_manager.get_guest_rls_filters(self):
clause = self.text(
f"({template_processor.process_template(rule['clause'])})"
)
all_filters.append(clause)

grouped_filters = [or_(*clauses) for clauses in filter_groups.values()]
all_filters.extend(grouped_filters)
return all_filters
except TemplateError as ex:
raise QueryObjectValidationError(
_(
"Error in jinja expression in RLS filters: %(msg)s",

Check warning on line 759 in superset/connectors/sqla/models.py

View check run for this annotation

Codecov / codecov/patch

superset/connectors/sqla/models.py#L758-L759

Added lines #L758 - L759 were not covered by tests
msg=ex.message,
)
) from ex


class AnnotationDatasource(BaseDatasource):
"""Dummy object so we can query annotations using 'Viz' objects just like
Expand Down
2 changes: 1 addition & 1 deletion superset/databases/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2274,6 +2274,6 @@ def schemas_access_for_file_upload(self, pk: int) -> Response:
# otherwise the database should have been filtered out
# in CsvToDatabaseForm
schemas_allowed_processed = security_manager.get_schemas_accessible_by_user(
database, schemas_allowed, True
database, database.get_default_catalog(), schemas_allowed, True
)
return self.response(200, schemas=schemas_allowed_processed)
9 changes: 5 additions & 4 deletions superset/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from superset.async_events.async_query_manager_factory import AsyncQueryManagerFactory
from superset.extensions.ssh import SSHManagerFactory
from superset.extensions.stats_logger import BaseStatsLoggerManager
from superset.security.manager import SupersetSecurityManager
from superset.utils.cache_manager import CacheManager
from superset.utils.encrypt import EncryptedFieldFactory
from superset.utils.feature_flag_manager import FeatureFlagManager
Expand Down Expand Up @@ -84,9 +85,9 @@ def get_files(bundle: str, asset_type: str = "js") -> list[str]:
return {
"js_manifest": lambda bundle: get_files(bundle, "js"),
"css_manifest": lambda bundle: get_files(bundle, "css"),
"assets_prefix": self.app.config["STATIC_ASSETS_PREFIX"]
if self.app
else "",
"assets_prefix": (
self.app.config["STATIC_ASSETS_PREFIX"] if self.app else ""
),
}

def parse_manifest_json(self) -> None:
Expand Down Expand Up @@ -132,7 +133,7 @@ def init_app(self, app: Flask) -> None:
migrate = Migrate()
profiling = ProfilingExtension()
results_backend_manager = ResultsBackendManager()
security_manager = LocalProxy(lambda: appbuilder.sm)
security_manager: SupersetSecurityManager = LocalProxy(lambda: appbuilder.sm)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice this will help the IDE a lot!

ssh_manager_factory = SSHManagerFactory()
stats_logger_manager = BaseStatsLoggerManager()
talisman = Talisman()
48 changes: 6 additions & 42 deletions superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import logging
import re
import uuid
from collections import defaultdict
from collections.abc import Hashable
from datetime import datetime, timedelta
from typing import Any, cast, NamedTuple, Optional, TYPE_CHECKING, Union
Expand Down Expand Up @@ -52,7 +51,7 @@
from sqlalchemy.sql.selectable import Alias, TableClause
from sqlalchemy_utils import UUIDType

from superset import app, db, is_feature_enabled, security_manager
from superset import app, db, is_feature_enabled
from superset.advanced_data_type.types import AdvancedDataTypeResponse
from superset.common.db_query_status import QueryStatus
from superset.common.utils.time_range_utils import get_since_until_from_time_range
Expand Down Expand Up @@ -806,47 +805,12 @@

def get_sqla_row_level_filters(
self,
template_processor: Optional[BaseTemplateProcessor] = None,
template_processor: Optional[BaseTemplateProcessor] = None, # pylint: disable=unused-argument
) -> list[TextClause]:
"""
Return the appropriate row level security filters for this table and the
current user. A custom username can be passed when the user is not present in the
Flask global namespace.

:param template_processor: The template processor to apply to the filters.
:returns: A list of SQL clauses to be ANDed together.
"""
template_processor = template_processor or self.get_template_processor()

all_filters: list[TextClause] = []
filter_groups: dict[Union[int, str], list[TextClause]] = defaultdict(list)
try:
for filter_ in security_manager.get_rls_filters(self):
clause = self.text(
f"({template_processor.process_template(filter_.clause)})"
)
if filter_.group_key:
filter_groups[filter_.group_key].append(clause)
else:
all_filters.append(clause)

if is_feature_enabled("EMBEDDED_SUPERSET"):
for rule in security_manager.get_guest_rls_filters(self):
clause = self.text(
f"({template_processor.process_template(rule['clause'])})"
)
all_filters.append(clause)

grouped_filters = [or_(*clauses) for clauses in filter_groups.values()]
all_filters.extend(grouped_filters)
return all_filters
except TemplateError as ex:
raise QueryObjectValidationError(
_(
"Error in jinja expression in RLS filters: %(msg)s",
msg=ex.message,
)
) from ex
# TODO: We should refactor this mixin and remove this method
# as it exists in the BaseDatasource and is not applicable
# for datasources of type query
return []

Check warning on line 813 in superset/models/helpers.py

View check run for this annotation

Codecov / codecov/patch

superset/models/helpers.py#L813

Added line #L813 was not covered by tests

def _process_sql_expression(
self,
Expand Down
4 changes: 1 addition & 3 deletions superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2643,9 +2643,7 @@ def has_guest_access(self, dashboard: "Dashboard") -> bool:
return False

dashboards = [
r
for r in user.resources
if r["type"] == GuestTokenResourceType.DASHBOARD.value
r for r in user.resources if r["type"] == GuestTokenResourceType.DASHBOARD
]

# TODO (embedded): remove this check once uuids are rolled out
Expand Down
36 changes: 32 additions & 4 deletions tests/integration_tests/security/guest_token_security_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,29 @@ def setUp(self) -> None:
self.authorized_guest = security_manager.get_guest_user_from_token(
{
"user": {},
"resources": [{"type": "dashboard", "id": str(self.embedded.uuid)}],
"resources": [
{
"type": GuestTokenResourceType.DASHBOARD,
"id": str(self.embedded.uuid),
}
],
"iat": 10,
"exp": 20,
"rls_rules": [],
}
)
self.unauthorized_guest = security_manager.get_guest_user_from_token(
{
"user": {},
"resources": [
{"type": "dashboard", "id": "06383667-3e02-4e5e-843f-44e9c5896b6c"}
{
"type": GuestTokenResourceType.DASHBOARD,
"id": "06383667-3e02-4e5e-843f-44e9c5896b6c",
}
],
"iat": 10,
"exp": 20,
"rls_rules": [],
}
)

Expand Down Expand Up @@ -247,15 +261,29 @@ def setUp(self) -> None:
self.authorized_guest = security_manager.get_guest_user_from_token(
{
"user": {},
"resources": [{"type": "dashboard", "id": str(self.embedded.uuid)}],
"resources": [
{
"type": GuestTokenResourceType.DASHBOARD,
"id": str(self.embedded.uuid),
}
],
"iat": 10,
"exp": 20,
"rls_rules": [],
}
)
self.unauthorized_guest = security_manager.get_guest_user_from_token(
{
"user": {},
"resources": [
{"type": "dashboard", "id": "06383667-3e02-4e5e-843f-44e9c5896b6c"}
{
"type": GuestTokenResourceType.DASHBOARD,
"id": "06383667-3e02-4e5e-843f-44e9c5896b6c",
}
],
"iat": 10,
"exp": 20,
"rls_rules": [],
}
)
self.chart = self.get_slice("Girls")
Expand Down
9 changes: 8 additions & 1 deletion tests/integration_tests/security/row_level_security_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,8 +646,15 @@ def guest_user_with_rls(self, rules: Optional[list[Any]] = None) -> GuestUser:
return security_manager.get_guest_user_from_token(
{
"user": {},
"resources": [{"type": GuestTokenResourceType.DASHBOARD.value}],
"resources": [
{
"type": GuestTokenResourceType.DASHBOARD,
"id": "06383667-3e02-4e5e-843f-44e9c5896b6c",
}
],
"rls_rules": rules,
"iat": 10,
"exp": 20,
}
)

Expand Down
42 changes: 28 additions & 14 deletions tests/unit_tests/charts/commands/importers/v1/import_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def test_import_chart(mocker: MockerFixture, session_with_schema: Session) -> No
Test importing a chart.
"""

mocker.patch.object(security_manager, "can_access", return_value=True)
mock_can_access = mocker.patch.object(
security_manager, "can_access", return_value=True
)

config = copy.deepcopy(chart_config)
config["datasource_id"] = 1
Expand All @@ -89,7 +91,7 @@ def test_import_chart(mocker: MockerFixture, session_with_schema: Session) -> No
assert chart.external_url is None

# Assert that the can write to chart was checked
security_manager.can_access.assert_called_once_with("can_write", "Chart")
mock_can_access.assert_called_once_with("can_write", "Chart")


def test_import_chart_managed_externally(
Expand All @@ -98,7 +100,9 @@ def test_import_chart_managed_externally(
"""
Test importing a chart that is managed externally.
"""
mocker.patch.object(security_manager, "can_access", return_value=True)
mock_can_access = mocker.patch.object(
security_manager, "can_access", return_value=True
)

config = copy.deepcopy(chart_config)
config["datasource_id"] = 1
Expand All @@ -111,7 +115,7 @@ def test_import_chart_managed_externally(
assert chart.external_url == "https://example.org/my_chart"

# Assert that the can write to chart was checked
security_manager.can_access.assert_called_once_with("can_write", "Chart")
mock_can_access.assert_called_once_with("can_write", "Chart")


def test_import_chart_without_permission(
Expand All @@ -121,7 +125,9 @@ def test_import_chart_without_permission(
"""
Test importing a chart when a user doesn't have permissions to create.
"""
mocker.patch.object(security_manager, "can_access", return_value=False)
mock_can_access = mocker.patch.object(
security_manager, "can_access", return_value=False
)

config = copy.deepcopy(chart_config)
config["datasource_id"] = 1
Expand All @@ -134,7 +140,7 @@ def test_import_chart_without_permission(
== "Chart doesn't exist and user doesn't have permission to create charts"
)
# Assert that the can write to chart was checked
security_manager.can_access.assert_called_once_with("can_write", "Chart")
mock_can_access.assert_called_once_with("can_write", "Chart")


def test_filter_chart_annotations(session: Session) -> None:
Expand Down Expand Up @@ -162,8 +168,12 @@ def test_import_existing_chart_without_permission(
"""
Test importing a chart when a user doesn't have permissions to modify.
"""
mocker.patch.object(security_manager, "can_access", return_value=True)
mocker.patch.object(security_manager, "can_access_chart", return_value=False)
mock_can_access = mocker.patch.object(
security_manager, "can_access", return_value=True
)
mock_can_access_chart = mocker.patch.object(
security_manager, "can_access_chart", return_value=False
)

slice = (
session_with_data.query(Slice)
Expand All @@ -180,8 +190,8 @@ def test_import_existing_chart_without_permission(
)

# Assert that the can write to chart was checked
security_manager.can_access.assert_called_once_with("can_write", "Chart")
security_manager.can_access_chart.assert_called_once_with(slice)
mock_can_access.assert_called_once_with("can_write", "Chart")
mock_can_access_chart.assert_called_once_with(slice)


def test_import_existing_chart_with_permission(
Expand All @@ -191,8 +201,12 @@ def test_import_existing_chart_with_permission(
"""
Test importing a chart that exists when a user has access permission to that chart.
"""
mocker.patch.object(security_manager, "can_access", return_value=True)
mocker.patch.object(security_manager, "can_access_chart", return_value=True)
mock_can_access = mocker.patch.object(
security_manager, "can_access", return_value=True
)
mock_can_access_chart = mocker.patch.object(
security_manager, "can_access_chart", return_value=True
)

admin = User(
first_name="Alice",
Expand All @@ -215,5 +229,5 @@ def test_import_existing_chart_with_permission(
with override_user(admin):
import_chart(config, overwrite=True)
# Assert that the can write to chart was checked
security_manager.can_access.assert_called_once_with("can_write", "Chart")
security_manager.can_access_chart.assert_called_once_with(slice)
mock_can_access.assert_called_once_with("can_write", "Chart")
mock_can_access_chart.assert_called_once_with(slice)
Loading
Loading