Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion superset/tasks/celery_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from superset import create_app
from superset.extensions import celery_app, db

from flask import has_app_context

# Init the Flask app / configure everything
flask_app = create_app()

Expand Down Expand Up @@ -68,4 +70,6 @@ def teardown( # pylint: disable=unused-argument
db.session.commit() # pylint: disable=consider-using-transaction

if not flask_app.config.get("CELERY_ALWAYS_EAGER"):
db.session.remove()
# Ensure session is removed only inside flask app context
if has_app_context():
db.session.remove()
28 changes: 28 additions & 0 deletions tests/integration_tests/celery_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,34 @@ def my_task(self):
)


def test_teardown_without_app_context():
"""Test teardown skips db.session.remove() outside app context.

Regression test for https://github.com/apache/superset/issues/36892
The task_postrun signal can fire after the app context is torn down,
so teardown() must check has_app_context() before calling db.session.remove().
"""
from superset.tasks.celery_app import teardown

assert not has_app_context(), "Test must run outside of app context"

with mock.patch("superset.tasks.celery_app.db.session.remove") as mock_remove:
teardown(retval="success")
mock_remove.assert_not_called()


def test_teardown_with_app_context():
"""Test teardown calls db.session.remove() inside app context."""
from superset.tasks.celery_app import teardown, flask_app

with flask_app.app_context():
assert has_app_context(), "Test must run inside app context"

with mock.patch("superset.tasks.celery_app.db.session.remove") as mock_remove:
teardown(retval="success")
mock_remove.assert_called_once()


def delete_tmp_view_or_table(name: str, ctas_method: CTASMethod):
db.get_engine().execute(f"DROP {ctas_method.name} IF EXISTS {name}")

Expand Down