diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 1eb4b307870d0..49615c39cba52 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -27,7 +27,7 @@ import numpy as np import pandas as pd import pyarrow as pa -from flask import current_app, Flask, g +from flask import ctx, current_app, Flask, g from sqlalchemy import text from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL @@ -227,12 +227,22 @@ def execute_with_cursor( execute_event = threading.Event() def _execute( - results: dict[str, Any], event: threading.Event, app: Flask + results: dict[str, Any], + event: threading.Event, + app: Flask, + g_copy: ctx._AppCtxGlobals, ) -> None: logger.debug("Query %d: Running query: %s", query_id, sql) try: + # Flask contexts are local to the thread that handles the request. + # When you spawn a new thread, it does not inherit the contexts + # from the parent thread, + # meaning the g object and other context-bound variables are not + # accessible with app.app_context(): + for key, value in g_copy.__dict__.items(): + setattr(g, key, value) cls.execute(cursor, sql, query.database) except Exception as ex: # pylint: disable=broad-except results["error"] = ex @@ -245,6 +255,7 @@ def _execute( execute_result, execute_event, current_app._get_current_object(), # pylint: disable=protected-access + g._get_current_object(), # pylint: disable=protected-access ), ) execute_thread.start() diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index a0923e8111860..f7183ba7d8536 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -23,6 +23,7 @@ import pandas as pd import pytest +from flask import g, has_app_context from pytest_mock import MockerFixture from requests.exceptions import ConnectionError as RequestsConnectionError from sqlalchemy import sql, text, types @@ -435,6 +436,33 @@ def _mock_execute(*args, **kwargs): ) +def test_execute_with_cursor_app_context(app, mocker: MockerFixture): + """Test that `execute_with_cursor` still contains the current app context""" + from superset.db_engine_specs.trino import TrinoEngineSpec + + mock_cursor = mocker.MagicMock() + mock_cursor.query_id = None + + mock_query = mocker.MagicMock() + g.some_value = "some_value" + + def _mock_execute(*args, **kwargs): + assert has_app_context() + assert g.some_value == "some_value" + + with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute): + with patch.dict( + "superset.config.DISALLOWED_SQL_FUNCTIONS", + {}, + clear=True, + ): + TrinoEngineSpec.execute_with_cursor( + cursor=mock_cursor, + sql="SELECT 1 FROM foo", + query=mock_query, + ) + + def test_get_columns(mocker: MockerFixture): """Test that ROW columns are not expanded without expand_rows""" from superset.db_engine_specs.trino import TrinoEngineSpec