Skip to content

Commit

Permalink
fix: [alert] should run alert query from report account (apache#17499)
Browse files Browse the repository at this point in the history
* fix: [alert] should run alert query from report account

* add solution2: override username for get_df

* add integration test
  • Loading branch information
Grace Guo authored and shcoderAlex committed Feb 7, 2022
1 parent cbc861c commit 38671ee
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 4 deletions.
5 changes: 3 additions & 2 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,11 +397,12 @@ def get_df( # pylint: disable=too-many-locals
sql: str,
schema: Optional[str] = None,
mutator: Optional[Callable[[pd.DataFrame], None]] = None,
username: Optional[str] = None,
) -> pd.DataFrame:
sqls = [str(s).strip(" ;") for s in sqlparse.parse(sql)]

engine = self.get_sqla_engine(schema=schema)
username = utils.get_username()
engine = self.get_sqla_engine(schema=schema, user_name=username)
username = utils.get_username() or username

def needs_conversion(df_series: pd.Series) -> bool:
return (
Expand Down
7 changes: 5 additions & 2 deletions superset/reports/commands/alert.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from celery.exceptions import SoftTimeLimitExceeded
from flask_babel import lazy_gettext as _

from superset import jinja_context
from superset import app, jinja_context
from superset.commands.base import BaseCommand
from superset.models.reports import ReportSchedule, ReportScheduleValidatorType
from superset.reports.commands.exceptions import (
Expand Down Expand Up @@ -146,8 +146,11 @@ def _execute_query(self) -> pd.DataFrame:
limited_rendered_sql = self._report_schedule.database.apply_limit_to_sql(
rendered_sql, ALERT_SQL_LIMIT
)
query_username = app.config["THUMBNAIL_SELENIUM_USER"]
start = default_timer()
df = self._report_schedule.database.get_df(limited_rendered_sql)
df = self._report_schedule.database.get_df(
sql=limited_rendered_sql, username=query_username
)
stop = default_timer()
logger.info(
"Query for %s took %.2f ms",
Expand Down
12 changes: 12 additions & 0 deletions tests/integration_tests/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,18 @@ def test_multi_statement(self):
df = main_db.get_df("USE superset; SELECT ';';", None)
self.assertEqual(df.iat[0, 0], ";")

@mock.patch("superset.models.core.Database.get_sqla_engine")
def test_username_param(self, mocked_get_sqla_engine):
main_db = get_example_database()
main_db.impersonate_user = True
test_username = "test_username_param"

if main_db.backend == "mysql":
main_db.get_df("USE superset; SELECT 1", username=test_username)
mocked_get_sqla_engine.assert_called_with(
schema=None, user_name="test_username_param",
)

@mock.patch("superset.models.core.create_engine")
def test_get_sqla_engine(self, mocked_create_engine):
model = Database(
Expand Down

0 comments on commit 38671ee

Please sign in to comment.