From f314685a8e2c45b6bd2c6d1f01653d19133b9a5f Mon Sep 17 00:00:00 2001 From: Geido <60598000+geido@users.noreply.github.com> Date: Fri, 4 Oct 2024 18:12:28 +0300 Subject: [PATCH] fix(Explore): Apply RLS at column values (#30490) Co-authored-by: Beto Dealmeida --- superset/models/helpers.py | 5 +- .../integration_tests/datasource/api_tests.py | 29 ++++++++++ tests/integration_tests/sqla_models_tests.py | 26 +++++++++ tests/unit_tests/models/helpers_test.py | 53 +++++++++++++++++++ 4 files changed, 112 insertions(+), 1 deletion(-) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 850a6e259f1b5..4085d3a0aabc7 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -1309,7 +1309,7 @@ def get_time_filter( # pylint: disable=too-many-arguments ) return and_(*l) - def values_for_column( + def values_for_column( # pylint: disable=too-many-locals self, column_name: str, limit: int = 10000, @@ -1345,6 +1345,9 @@ def values_for_column( if self.fetch_values_predicate: qry = qry.where(self.get_fetch_values_predicate(template_processor=tp)) + rls_filters = self.get_sqla_row_level_filters(template_processor=tp) + qry = qry.where(and_(*rls_filters)) + with self.database.get_sqla_engine() as engine: sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True})) sql = self._apply_cte(sql, cte) diff --git a/tests/integration_tests/datasource/api_tests.py b/tests/integration_tests/datasource/api_tests.py index d9f3650793f39..e810e02ee5716 100644 --- a/tests/integration_tests/datasource/api_tests.py +++ b/tests/integration_tests/datasource/api_tests.py @@ -18,6 +18,7 @@ from unittest.mock import ANY, patch import pytest +from sqlalchemy.sql.elements import TextClause from superset import db, security_manager from superset.connectors.sqla.models import SqlaTable @@ -176,3 +177,31 @@ def test_get_column_values_denormalize_column(self, denormalize_name_mock): table.normalize_columns = False self.client.get(f"api/v1/datasource/table/{table.id}/column/col2/values/") # noqa: F841 denormalize_name_mock.assert_called_with(ANY, "col2") + + @pytest.mark.usefixtures("app_context", "virtual_dataset") + def test_get_column_values_with_rls(self): + self.login(ADMIN_USERNAME) + table = self.get_virtual_dataset() + with patch.object( + table, "get_sqla_row_level_filters", return_value=[TextClause("col2 = 'b'")] + ): + rv = self.client.get( + f"api/v1/datasource/table/{table.id}/column/col2/values/" + ) + self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(response["result"], ["b"]) + + @pytest.mark.usefixtures("app_context", "virtual_dataset") + def test_get_column_values_with_rls_no_values(self): + self.login(ADMIN_USERNAME) + table = self.get_virtual_dataset() + with patch.object( + table, "get_sqla_row_level_filters", return_value=[TextClause("col2 = 'q'")] + ): + rv = self.client.get( + f"api/v1/datasource/table/{table.id}/column/col2/values/" + ) + self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(response["result"], []) diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index ca08842ebea13..1b5245568813c 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -626,6 +626,32 @@ def test_values_for_column_on_text_column(text_column_table): assert len(with_null) == 8 +def test_values_for_column_on_text_column_with_rls(text_column_table): + with patch.object( + text_column_table, + "get_sqla_row_level_filters", + return_value=[ + TextClause("foo = 'foo'"), + ], + ): + with_rls = text_column_table.values_for_column(column_name="foo", limit=10000) + assert with_rls == ["foo"] + assert len(with_rls) == 1 + + +def test_values_for_column_on_text_column_with_rls_no_values(text_column_table): + with patch.object( + text_column_table, + "get_sqla_row_level_filters", + return_value=[ + TextClause("foo = 'bar'"), + ], + ): + with_rls = text_column_table.values_for_column(column_name="foo", limit=10000) + assert with_rls == [] + assert len(with_rls) == 0 + + def test_filter_on_text_column(text_column_table): table = text_column_table # null value should be replaced diff --git a/tests/unit_tests/models/helpers_test.py b/tests/unit_tests/models/helpers_test.py index 009cff0adf4c5..c87b217928047 100644 --- a/tests/unit_tests/models/helpers_test.py +++ b/tests/unit_tests/models/helpers_test.py @@ -21,6 +21,7 @@ from contextlib import contextmanager from typing import TYPE_CHECKING +from unittest.mock import patch import pytest from pytest_mock import MockerFixture @@ -85,6 +86,58 @@ def test_values_for_column(database: Database) -> None: assert table.values_for_column("a") == [1, None] +def test_values_for_column_with_rls(database: Database) -> None: + """ + Test the `values_for_column` method with RLS enabled. + """ + from sqlalchemy.sql.elements import TextClause + + from superset.connectors.sqla.models import SqlaTable, TableColumn + + table = SqlaTable( + database=database, + schema=None, + table_name="t", + columns=[ + TableColumn(column_name="a"), + ], + ) + with patch.object( + table, + "get_sqla_row_level_filters", + return_value=[ + TextClause("a = 1"), + ], + ): + assert table.values_for_column("a") == [1] + + +def test_values_for_column_with_rls_no_values(database: Database) -> None: + """ + Test the `values_for_column` method with RLS enabled and no values. + """ + from sqlalchemy.sql.elements import TextClause + + from superset.connectors.sqla.models import SqlaTable, TableColumn + + table = SqlaTable( + database=database, + schema=None, + table_name="t", + columns=[ + TableColumn(column_name="a"), + ], + ) + with patch.object( + table, + "get_sqla_row_level_filters", + return_value=[ + TextClause("a = 2"), + ], + ): + assert table.values_for_column("a") == [] + + def test_values_for_column_calculated( mocker: MockerFixture, database: Database,