Skip to content

Commit

Permalink
fix: trino thread app missing full context (#29981)
Browse files Browse the repository at this point in the history
  • Loading branch information
dpgaspar authored Aug 22, 2024
1 parent c049771 commit 4d821f4
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 2 deletions.
15 changes: 13 additions & 2 deletions superset/db_engine_specs/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import numpy as np
import pandas as pd
import pyarrow as pa
from flask import current_app, Flask, g
from flask import ctx, current_app, Flask, g
from sqlalchemy import text
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
Expand Down Expand Up @@ -227,12 +227,22 @@ def execute_with_cursor(
execute_event = threading.Event()

def _execute(
results: dict[str, Any], event: threading.Event, app: Flask
results: dict[str, Any],
event: threading.Event,
app: Flask,
g_copy: ctx._AppCtxGlobals,
) -> None:
logger.debug("Query %d: Running query: %s", query_id, sql)

try:
# Flask contexts are local to the thread that handles the request.
# When you spawn a new thread, it does not inherit the contexts
# from the parent thread,
# meaning the g object and other context-bound variables are not
# accessible
with app.app_context():
for key, value in g_copy.__dict__.items():
setattr(g, key, value)
cls.execute(cursor, sql, query.database)
except Exception as ex: # pylint: disable=broad-except
results["error"] = ex
Expand All @@ -245,6 +255,7 @@ def _execute(
execute_result,
execute_event,
current_app._get_current_object(), # pylint: disable=protected-access
g._get_current_object(), # pylint: disable=protected-access
),
)
execute_thread.start()
Expand Down
28 changes: 28 additions & 0 deletions tests/unit_tests/db_engine_specs/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import pandas as pd
import pytest
from flask import g, has_app_context
from pytest_mock import MockerFixture
from requests.exceptions import ConnectionError as RequestsConnectionError
from sqlalchemy import sql, text, types
Expand Down Expand Up @@ -435,6 +436,33 @@ def _mock_execute(*args, **kwargs):
)


def test_execute_with_cursor_app_context(app, mocker: MockerFixture):
"""Test that `execute_with_cursor` still contains the current app context"""
from superset.db_engine_specs.trino import TrinoEngineSpec

mock_cursor = mocker.MagicMock()
mock_cursor.query_id = None

mock_query = mocker.MagicMock()
g.some_value = "some_value"

def _mock_execute(*args, **kwargs):
assert has_app_context()
assert g.some_value == "some_value"

with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute):
with patch.dict(
"superset.config.DISALLOWED_SQL_FUNCTIONS",
{},
clear=True,
):
TrinoEngineSpec.execute_with_cursor(
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
)


def test_get_columns(mocker: MockerFixture):
"""Test that ROW columns are not expanded without expand_rows"""
from superset.db_engine_specs.trino import TrinoEngineSpec
Expand Down

0 comments on commit 4d821f4

Please sign in to comment.