From 3ececb2c79307bd943aad116d7b0b5933af8185a Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Mon, 9 Jan 2023 12:27:44 -0800 Subject: [PATCH] Propagate logs to stdout when in k8s executor pod (#28440) This is necessary because file task handler reads from pod logs (i.e. stdout) when pod is running. Previously we were not propagating task logs from `airflow.task`, presumably to avoid duplicating entries, because we had copied the handler to root. However, if we just remove the handler from task, we can safely enable propagation here, since there won't be multiple task handlers floating around. --- airflow/cli/commands/task_command.py | 124 +++++++-- airflow/settings.py | 2 + airflow/utils/log/file_task_handler.py | 35 ++- airflow/utils/log/logging_mixin.py | 13 +- tests/cli/commands/test_task_command.py | 109 ++++++++ .../task_runner/test_standard_task_runner.py | 248 +++++++++--------- 6 files changed, 362 insertions(+), 169 deletions(-) diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index 1824bc722ac3bb..0f3f5a8bba6cac 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -23,6 +23,7 @@ import json import logging import os +import sys import textwrap from contextlib import contextmanager, redirect_stderr, redirect_stdout, suppress from typing import Generator, Union @@ -43,6 +44,7 @@ from airflow.models.dag import DAG from airflow.models.dagrun import DagRun from airflow.models.operator import needs_expansion +from airflow.settings import IS_K8S_EXECUTOR_POD from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS from airflow.typing_compat import Literal @@ -283,38 +285,52 @@ def _extract_external_executor_id(args) -> str | None: @contextmanager -def _capture_task_logs(ti: TaskInstance) -> Generator[None, None, None]: +def _move_task_handlers_to_root(ti: TaskInstance) -> Generator[None, None, None]: """ - Manage logging context for a task run. + Move handlers for task logging to root logger. - - Replace the root logger configuration with the airflow.task configuration - so we can capture logs from any custom loggers used in the task. + We want anything logged during task run to be propagated to task log handlers. + If running in a k8s executor pod, also keep the stream handler on root logger + so that logs are still emitted to stdout. + """ + # nothing to do + if not ti.log.handlers or settings.DONOT_MODIFY_HANDLERS: + yield + return + + # Move task handlers to root and reset task logger and restore original logger settings after exit. + # If k8s executor, we need to ensure that root logger has a console handler, so that + # task logs propagate to stdout (this is how webserver retrieves them while task is running). + root_logger = logging.getLogger() + console_handler = next((h for h in root_logger.handlers if h.name == "console"), None) + with LoggerMutationHelper(root_logger), LoggerMutationHelper(ti.log) as task_helper: + task_helper.move(root_logger) + if IS_K8S_EXECUTOR_POD: + if console_handler and console_handler not in root_logger.handlers: + root_logger.addHandler(console_handler) + yield - - Redirect stdout and stderr to the task instance log, as INFO and WARNING - level messages, respectively. 
+@contextmanager +def _redirect_stdout_to_ti_log(ti: TaskInstance) -> Generator[None, None, None]: """ - modify = not settings.DONOT_MODIFY_HANDLERS - if modify: - root_logger, task_logger = logging.getLogger(), logging.getLogger("airflow.task") + Redirect stdout to ti logger. - orig_level = root_logger.level - root_logger.setLevel(task_logger.level) - orig_handlers = root_logger.handlers.copy() - root_logger.handlers[:] = task_logger.handlers + Redirect stdout and stderr to the task instance log as INFO and WARNING + level messages, respectively. - try: + If stdout already redirected (possible when task running with option + `--local`), don't redirect again. + """ + # if sys.stdout is StreamLogWriter, it means we already redirected + # likely before forking in LocalTaskJob + if not isinstance(sys.stdout, StreamLogWriter): info_writer = StreamLogWriter(ti.log, logging.INFO) warning_writer = StreamLogWriter(ti.log, logging.WARNING) - with redirect_stdout(info_writer), redirect_stderr(warning_writer): yield - - finally: - if modify: - # Restore the root logger to its original state. - root_logger.setLevel(orig_level) - root_logger.handlers[:] = orig_handlers + else: + yield class TaskCommandMarker: @@ -366,12 +382,6 @@ def task_run(args, dag=None): settings.MASK_SECRETS_IN_LOGS = True - # IMPORTANT, have to re-configure ORM with the NullPool, otherwise, each "run" command may leave - # behind multiple open sleeping connections while heartbeating, which could - # easily exceed the database connection limit when - # processing hundreds of simultaneous tasks. - settings.reconfigure_orm(disable_connection_pool=True) - get_listener_manager().hook.on_starting(component=TaskCommandMarker()) if args.pickle: @@ -390,11 +400,19 @@ def task_run(args, dag=None): log.info("Running %s on host %s", ti, hostname) + # IMPORTANT, have to re-configure ORM with the NullPool, otherwise, each "run" command may leave + # behind multiple open sleeping connections while heartbeating, which could + # easily exceed the database connection limit when + # processing hundreds of simultaneous tasks. + # this should be last thing before running, to reduce likelihood of an open session + # which can cause trouble if running process in a fork. + settings.reconfigure_orm(disable_connection_pool=True) + try: if args.interactive: _run_task_by_selected_method(args, dag, ti) else: - with _capture_task_logs(ti): + with _move_task_handlers_to_root(ti), _redirect_stdout_to_ti_log(ti): _run_task_by_selected_method(args, dag, ti) finally: try: @@ -644,3 +662,53 @@ def task_clear(args): include_subdags=not args.exclude_subdags, include_parentdag=not args.exclude_parentdag, ) + + +class LoggerMutationHelper: + """ + Helper for moving and resetting handlers and other logger attrs. + + :meta private: + """ + + def __init__(self, logger): + self.handlers = logger.handlers[:] + self.level = logger.level + self.propagate = logger.propagate + self.source_logger = logger + + def apply(self, logger, replace=True): + """ + Set ``logger`` with attrs stored on instance. + + If ``logger`` is root logger, don't change propagate. + """ + if replace: + logger.handlers[:] = self.handlers + else: + for h in self.handlers: + if h not in logger.handlers: + logger.addHandler(h) + logger.level = self.level + if logger is not logging.getLogger(): + logger.propagate = self.propagate + + def move(self, logger, replace=True): + """ + Replace ``logger`` attrs with those from source. 
+ + :param logger: target logger + :param replace: if True, remove all handlers from target first; otherwise add if not present. + """ + self.apply(logger, replace=replace) + self.source_logger.propagate = True + self.source_logger.handlers[:] = [] + + def reset(self): + self.apply(self.source_logger) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.reset() diff --git a/airflow/settings.py b/airflow/settings.py index 09d657a6b2dbdb..64678b80181589 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -624,6 +624,8 @@ def initialize(): executor_constants.CELERY_KUBERNETES_EXECUTOR, executor_constants.LOCAL_KUBERNETES_EXECUTOR, } +IS_K8S_EXECUTOR_POD = bool(os.environ.get("AIRFLOW_IS_K8S_EXECUTOR_POD", "")) +"""Will be True if running in kubernetes executor pod.""" HIDE_SENSITIVE_VAR_CONN_FIELDS = conf.getboolean("core", "hide_sensitive_var_conn_fields") diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index 9deca0a9cd445f..09fbbbe0976be0 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -30,13 +30,13 @@ from airflow.exceptions import RemovedInAirflow3Warning from airflow.utils.context import Context from airflow.utils.helpers import parse_template_string, render_template_to_string +from airflow.utils.log.logging_mixin import SetContextPropagate from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler from airflow.utils.session import create_session from airflow.utils.state import State if TYPE_CHECKING: from airflow.models import TaskInstance - from airflow.utils.log.logging_mixin import SetContextPropagate class FileTaskHandler(logging.Handler): @@ -62,11 +62,23 @@ def __init__(self, base_log_folder: str, filename_template: str | None = None): # handler, not the one that calls super()__init__. stacklevel=(2 if type(self) == FileTaskHandler else 3), ) + self.maintain_propagate: bool = False + """ + If true, overrides default behavior of setting propagate=False + + :meta private: + """ def set_context(self, ti: TaskInstance) -> None | SetContextPropagate: """ Provide task_instance context to airflow task handler. + Generally speaking returns None. But if attr `maintain_propagate` has + been set to propagate, then returns sentinel MAINTAIN_PROPAGATE. This + has the effect of overriding the default behavior to set `propagate` + to False whenever set_context is called. At time of writing, this + functionality is only used in unit testing. 
+ :param ti: task instance object """ local_loc = self._init_file(ti) @@ -74,7 +86,7 @@ def set_context(self, ti: TaskInstance) -> None | SetContextPropagate: if self.formatter: self.handler.setFormatter(self.formatter) self.handler.setLevel(self.level) - return None + return SetContextPropagate.MAINTAIN_PROPAGATE if self.maintain_propagate else None def emit(self, record): if self.handler: @@ -92,16 +104,17 @@ def _render_filename(self, ti: TaskInstance, try_number: int) -> str: with create_session() as session: dag_run = ti.get_dagrun(session=session) template = dag_run.get_log_template(session=session).filename - str_tpl, jinja_tpl = parse_template_string(template) + str_tpl, jinja_tpl = parse_template_string(template) - if jinja_tpl: - if hasattr(ti, "task"): - context = ti.get_template_context() - else: - context = Context(ti=ti, ts=dag_run.logical_date.isoformat()) - context["try_number"] = try_number - return render_template_to_string(jinja_tpl, context) - elif str_tpl: + if jinja_tpl: + if hasattr(ti, "task"): + context = ti.get_template_context(session=session) + else: + context = Context(ti=ti, ts=dag_run.logical_date.isoformat()) + context["try_number"] = try_number + return render_template_to_string(jinja_tpl, context) + + if str_tpl: try: dag = ti.task.dag except AttributeError: # ti.task is not always set. diff --git a/airflow/utils/log/logging_mixin.py b/airflow/utils/log/logging_mixin.py index b8f5a0871c8a95..85ff71a94f1bfc 100644 --- a/airflow/utils/log/logging_mixin.py +++ b/airflow/utils/log/logging_mixin.py @@ -26,6 +26,8 @@ from logging import Handler, Logger, StreamHandler from typing import IO, cast +from airflow.settings import IS_K8S_EXECUTOR_POD + # 7-bit C1 ANSI escape sequences ANSI_ESCAPE = re.compile(r"\x1B[@-_][0-?]*[ -/]*[@-~]") @@ -165,9 +167,10 @@ def isatty(self): class RedirectStdHandler(StreamHandler): """ - This class is like a StreamHandler using sys.stderr/stdout, but always uses + This class is like a StreamHandler using sys.stderr/stdout, but uses whatever sys.stderr/stderr is currently set to rather than the value of - sys.stderr/stdout at handler construction time. + sys.stderr/stdout at handler construction time, except when running a + task in a kubernetes executor pod. 
""" def __init__(self, stream): @@ -179,13 +182,17 @@ def __init__(self, stream): self._use_stderr = True if "stdout" in stream: self._use_stderr = False - + self._orig_stream = sys.stdout + else: + self._orig_stream = sys.stderr # StreamHandler tries to set self.stream Handler.__init__(self) @property def stream(self): """Returns current stream.""" + if IS_K8S_EXECUTOR_POD: + return self._orig_stream if self._use_stderr: return sys.stderr diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py index eb50c3635c903d..180e012e49820b 100644 --- a/tests/cli/commands/test_task_command.py +++ b/tests/cli/commands/test_task_command.py @@ -29,6 +29,7 @@ from contextlib import contextmanager, redirect_stdout from pathlib import Path from unittest import mock +from unittest.mock import sentinel import pendulum import pytest @@ -37,6 +38,7 @@ from airflow import DAG from airflow.cli import cli_parser from airflow.cli.commands import task_command +from airflow.cli.commands.task_command import LoggerMutationHelper from airflow.configuration import conf from airflow.exceptions import AirflowException, DagRunNotFound from airflow.models import DagBag, DagRun, Pool, TaskInstance @@ -662,6 +664,48 @@ def test_external_executor_id_present_for_process_run_task(self, mock_local_job) external_executor_id="ABCD12345", ) + @pytest.mark.parametrize("is_k8s", ["true", ""]) + def test_logging_with_run_task_stdout_k8s_executor_pod(self, is_k8s): + """ + When running task --local as k8s executor pod, all logging should make it to stdout. + Otherwise, all logging after "running TI" is redirected to logs (and the actual log + file content is tested elsewhere in this module). + + Unfortunately, to test stdout, we have to test this by running as a subprocess because + the stdout redirection & log capturing behavior is not compatible with pytest's stdout + capturing behavior. Running as subprocess takes pytest out of the equation and + verifies with certainty the behavior. + """ + import subprocess + + with mock.patch.dict("os.environ", AIRFLOW_IS_K8S_EXECUTOR_POD=is_k8s): + with subprocess.Popen( + args=["airflow", *self.task_args, "-S", self.dag_path], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) as process: + output, err = process.communicate() + lines = [] + found_start = False + for line_ in output.splitlines(): + line = line_.decode("utf-8") + if "Running 20 + self.assert_log_line("Starting attempt 1 of 1", lines) + self.assert_log_line("Exporting the following env vars", lines) + self.assert_log_line("Log from DAG Logger", lines) + self.assert_log_line("Log from TI Logger", lines) + self.assert_log_line("Log from Print statement", lines, expect_from_logging_mixin=True) + self.assert_log_line("Task exited with return code 0", lines) + else: + # when not k8s executor pod, most output is redirected to logs + assert len(lines) == 1 + @unittest.skipIf(not hasattr(os, "fork"), "Forking not available") def test_logging_with_run_task(self): with conf_vars({("core", "dags_folder"): self.dag_path}): @@ -834,3 +878,68 @@ def test_context_with_run(): text == "_AIRFLOW_PARSING_CONTEXT_DAG_ID=test_parsing_context\n" "_AIRFLOW_PARSING_CONTEXT_TASK_ID=task1\n" ) + + +class TestLoggerMutationHelper: + @pytest.mark.parametrize("target_name", ["test_apply_target", None]) + def test_apply(self, target_name): + """ + Handlers, level and propagate should be applied on target. 
+ """ + src = logging.getLogger(f"test_apply_source_{target_name}") + src.propagate = False + src.addHandler(sentinel.handler) + src.setLevel(-1) + obj = LoggerMutationHelper(src) + tgt = logging.getLogger("test_apply_target") + obj.apply(tgt) + assert tgt.handlers == [sentinel.handler] + assert tgt.propagate is False if target_name else True # root propagate unchanged + assert tgt.level == -1 + + def test_apply_no_replace(self): + """ + Handlers, level and propagate should be applied on target. + """ + src = logging.getLogger("test_apply_source_no_repl") + tgt = logging.getLogger("test_apply_target_no_repl") + h1 = logging.Handler() + h1.name = "h1" + h2 = logging.Handler() + h2.name = "h2" + h3 = logging.Handler() + h3.name = "h3" + src.handlers[:] = [h1, h2] + tgt.handlers[:] = [h2, h3] + LoggerMutationHelper(src).apply(tgt, replace=False) + assert tgt.handlers == [h2, h3, h1] + + def test_move(self): + """Move should apply plus remove source handler, set propagate to True""" + src = logging.getLogger("test_move_source") + src.propagate = False + src.addHandler(sentinel.handler) + src.setLevel(-1) + obj = LoggerMutationHelper(src) + tgt = logging.getLogger("test_apply_target") + obj.move(tgt) + assert tgt.handlers == [sentinel.handler] + assert tgt.propagate is False + assert tgt.level == -1 + assert src.propagate is True and obj.propagate is False + assert src.level == obj.level + assert src.handlers == [] and obj.handlers == tgt.handlers + + def test_reset(self): + src = logging.getLogger("test_move_reset") + src.propagate = True + src.addHandler(sentinel.h1) + src.setLevel(-1) + obj = LoggerMutationHelper(src) + src.propagate = False + src.addHandler(sentinel.h2) + src.setLevel(-2) + obj.reset() + assert src.propagate is True + assert src.handlers == [sentinel.h1] + assert src.level == -1 diff --git a/tests/task/task_runner/test_standard_task_runner.py b/tests/task/task_runner/test_standard_task_runner.py index 797462136a47cc..736ee9c0ec1098 100644 --- a/tests/task/task_runner/test_standard_task_runner.py +++ b/tests/task/task_runner/test_standard_task_runner.py @@ -20,22 +20,22 @@ import logging import os import time -from logging.config import dictConfig +from contextlib import contextmanager from pathlib import Path from unittest import mock +from unittest.mock import patch import psutil import pytest -from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG from airflow.jobs.local_task_job import LocalTaskJob from airflow.listeners.listener import get_listener_manager from airflow.models.dagbag import DagBag from airflow.models.taskinstance import TaskInstance from airflow.task.task_runner.standard_task_runner import StandardTaskRunner from airflow.utils import timezone +from airflow.utils.log.file_task_handler import FileTaskHandler from airflow.utils.platform import getuser -from airflow.utils.session import create_session from airflow.utils.state import State from airflow.utils.timeout import timeout from tests.listeners.file_write_listener import FileWriteListener @@ -45,28 +45,40 @@ DEFAULT_DATE = timezone.datetime(2016, 1, 1) -TASK_FORMAT = "{{%(filename)s:%(lineno)d}} %(levelname)s - %(message)s" - -LOGGING_CONFIG = { - "version": 1, - "disable_existing_loggers": False, - "formatters": { - "airflow.task": {"format": TASK_FORMAT}, - }, - "handlers": { - "console": { - "class": "logging.StreamHandler", - "formatter": "airflow.task", - "stream": "ext://sys.stdout", - }, - }, - "loggers": {"airflow": {"handlers": ["console"], "level": "INFO", 
"propagate": True}}, -} +TASK_FORMAT = "%(filename)s:%(lineno)d %(levelname)s - %(message)s" +@contextmanager +def propagate_task_logger(): + """ + Set `airflow.task` logger to propagate. + + Apparently, caplog doesn't work if you don't propagate messages to root. + + But the normal behavior of the `airflow.task` logger is not to propagate. + + When freshly configured, the logger is set to propagate. However, + ordinarily when set_context is called, this is set to False. + + To override this behavior, so that the messages make it to caplog, we + must tell the handler to maintain its current setting. + """ + logger = logging.getLogger("airflow.task") + h = logger.handlers[0] + assert isinstance(h, FileTaskHandler) # just to make sure / document + _propagate = h.maintain_propagate + if _propagate is False: + h.maintain_propagate = True + try: + yield + finally: + if _propagate is False: + h.maintain_propagate = _propagate + + +@pytest.mark.usefixtures("reset_logging_config") class TestStandardTaskRunner: - @pytest.fixture(autouse=True, scope="class") - def logging_and_db(self): + def setup_class(self): """ This fixture sets up logging to have a different setup on the way in (as the test environment does not have enough context for the normal @@ -74,15 +86,13 @@ def logging_and_db(self): """ get_listener_manager().clear() clear_db_runs() - dictConfig(LOGGING_CONFIG) yield - airflow_logger = logging.getLogger("airflow") - airflow_logger.handlers = [] clear_db_runs() - dictConfig(DEFAULT_LOGGING_CONFIG) get_listener_manager().clear() - def test_start_and_terminate(self): + @patch("airflow.utils.log.file_task_handler.FileTaskHandler._init_file") + def test_start_and_terminate(self, mock_init): + mock_init.return_value = "/tmp/any" local_task_job = mock.Mock() local_task_job.task_instance = mock.MagicMock() local_task_job.task_instance.run_as_user = None @@ -131,38 +141,34 @@ def test_notifies_about_start_and_stop(self): ) dag = dagbag.dags.get("test_example_bash_operator") task = dag.get_task("runme_1") + dag.create_dagrun( + run_id="test", + data_interval=(DEFAULT_DATE, DEFAULT_DATE), + state=State.RUNNING, + start_date=DEFAULT_DATE, + ) + ti = TaskInstance(task=task, run_id="test") + job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True) + runner = StandardTaskRunner(job1) + runner.start() - with create_session() as session: - dag.create_dagrun( - run_id="test", - data_interval=(DEFAULT_DATE, DEFAULT_DATE), - state=State.RUNNING, - start_date=DEFAULT_DATE, - session=session, - ) - ti = TaskInstance(task=task, run_id="test") - job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True) - session.commit() - ti.refresh_from_task(task) - - runner = StandardTaskRunner(job1) - runner.start() - - # Wait until process sets its pgid to be equal to pid - with timeout(seconds=1): - while True: - runner_pgid = os.getpgid(runner.process.pid) - if runner_pgid == runner.process.pid: - break - time.sleep(0.01) - - # Wait till process finishes - assert runner.return_code(timeout=10) is not None - with open(path_listener_writer) as f: - assert f.readline() == "on_starting\n" - assert f.readline() == "before_stopping\n" - - def test_start_and_terminate_run_as_user(self): + # Wait until process makes itself the leader of it's own process group + with timeout(seconds=1): + while True: + runner_pgid = os.getpgid(runner.process.pid) + if runner_pgid == runner.process.pid: + break + time.sleep(0.01) + + # Wait till process finishes + assert runner.return_code(timeout=10) is not None + with 
open(path_listener_writer) as f: + assert f.readline() == "on_starting\n" + assert f.readline() == "before_stopping\n" + + @patch("airflow.utils.log.file_task_handler.FileTaskHandler._init_file") + def test_start_and_terminate_run_as_user(self, mock_init): + mock_init.return_value = "/tmp/any" local_task_job = mock.Mock() local_task_job.task_instance = mock.MagicMock() local_task_job.task_instance.task_id = "task_id" @@ -195,13 +201,15 @@ def test_start_and_terminate_run_as_user(self): assert runner.return_code() is not None - def test_early_reap_exit(self, caplog): + @propagate_task_logger() + @patch("airflow.utils.log.file_task_handler.FileTaskHandler._init_file") + def test_early_reap_exit(self, mock_init, caplog): """ Tests that when a child process running a task is killed externally (e.g. by an OOM error, which we fake here), then we get return code -9 and a log message. """ - # Set up mock task + mock_init.return_value = "/tmp/any" local_task_job = mock.Mock() local_task_job.task_instance = mock.MagicMock() local_task_job.task_instance.task_id = "task_id" @@ -256,45 +264,36 @@ def test_on_kill(self): ) dag = dagbag.dags.get("test_on_kill") task = dag.get_task("task1") + dag.create_dagrun( + run_id="test", + data_interval=(DEFAULT_DATE, DEFAULT_DATE), + state=State.RUNNING, + start_date=DEFAULT_DATE, + ) + ti = TaskInstance(task=task, run_id="test") + job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True) + runner = StandardTaskRunner(job1) + runner.start() + + with timeout(seconds=3): + while True: + runner_pgid = os.getpgid(runner.process.pid) + if runner_pgid == runner.process.pid: + break + time.sleep(0.01) + + processes = list(self._procs_in_pgroup(runner_pgid)) - with create_session() as session: - dag.create_dagrun( - run_id="test", - data_interval=(DEFAULT_DATE, DEFAULT_DATE), - state=State.RUNNING, - start_date=DEFAULT_DATE, - session=session, - ) - ti = TaskInstance(task=task, run_id="test") - job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True) - session.commit() - ti.refresh_from_task(task) - - runner = StandardTaskRunner(job1) - runner.start() - - with timeout(seconds=3): - while True: - runner_pgid = os.getpgid(runner.process.pid) - if runner_pgid == runner.process.pid: - break - time.sleep(0.01) - - processes = list(self._procs_in_pgroup(runner_pgid)) - - logging.info("Waiting for the task to start") - with timeout(seconds=20): - while True: - if os.path.exists(path_on_kill_running): - break - time.sleep(0.01) - logging.info("Task started. Give the task some time to settle") - time.sleep(3) - logging.info("Terminating processes %s belonging to %s group", processes, runner_pgid) - runner.terminate() - session.close() # explicitly close as `create_session`s commit will blow up otherwise - - ti.refresh_from_db() + logging.info("Waiting for the task to start") + with timeout(seconds=20): + while True: + if os.path.exists(path_on_kill_running): + break + time.sleep(0.01) + logging.info("Task started. 
Give the task some time to settle") + time.sleep(3) + logging.info("Terminating processes %s belonging to %s group", processes, runner_pgid) + runner.terminate() logging.info("Waiting for the on kill killed file to appear") with timeout(seconds=4): @@ -323,35 +322,30 @@ def test_parsing_context(self): dag = dagbag.dags.get("test_parsing_context") task = dag.get_task("task1") - with create_session() as session: - dag.create_dagrun( - run_id="test", - data_interval=(DEFAULT_DATE, DEFAULT_DATE), - state=State.RUNNING, - start_date=DEFAULT_DATE, - session=session, - ) - ti = TaskInstance(task=task, run_id="test") - job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True) - session.commit() - ti.refresh_from_task(task) - - runner = StandardTaskRunner(job1) - runner.start() - - # Wait until process sets its pgid to be equal to pid - with timeout(seconds=1): - while True: - runner_pgid = os.getpgid(runner.process.pid) - if runner_pgid == runner.process.pid: - break - time.sleep(0.01) - - assert runner_pgid > 0 - assert runner_pgid != os.getpgid(0), "Task should be in a different process group to us" - processes = list(self._procs_in_pgroup(runner_pgid)) - psutil.wait_procs([runner.process]) - session.close() + dag.create_dagrun( + run_id="test", + data_interval=(DEFAULT_DATE, DEFAULT_DATE), + state=State.RUNNING, + start_date=DEFAULT_DATE, + ) + ti = TaskInstance(task=task, run_id="test") + job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True) + runner = StandardTaskRunner(job1) + runner.start() + + # Wait until process sets its pgid to be equal to pid + with timeout(seconds=1): + while True: + runner_pgid = os.getpgid(runner.process.pid) + if runner_pgid == runner.process.pid: + break + time.sleep(0.01) + + assert runner_pgid > 0 + assert runner_pgid != os.getpgid(0), "Task should be in a different process group to us" + processes = list(self._procs_in_pgroup(runner_pgid)) + psutil.wait_procs([runner.process]) + for process in processes: assert not psutil.pid_exists(process.pid), f"{process} is still alive" assert runner.return_code() == 0
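
Illustrative sketch (not part of the patch): the core of the change above is that during `airflow tasks run`, the handlers configured on the `airflow.task` logger are moved onto the root logger, the task logger is left to propagate, and, when running inside a k8s executor pod (detected via the `AIRFLOW_IS_K8S_EXECUTOR_POD` env var), a console handler is kept on root so task output also reaches stdout, which is where the webserver reads logs from while the task is running. The standalone snippet below mirrors that handler juggling under simplified assumptions; the logger names, file path, and helper class are hypothetical stand-ins rather than Airflow APIs (the real code is `LoggerMutationHelper` and `_move_task_handlers_to_root` in airflow/cli/commands/task_command.py).

# Sketch only; a simplified stand-in for LoggerMutationHelper /
# _move_task_handlers_to_root above, not the actual Airflow implementation.
from __future__ import annotations

import logging
import os
import sys
from contextlib import contextmanager

# The patch reads this flag once, in airflow/settings.py.
IS_K8S_EXECUTOR_POD = bool(os.environ.get("AIRFLOW_IS_K8S_EXECUTOR_POD", ""))


class LoggerStateSnapshot:
    """Remember a logger's handlers, level and propagate so they can be restored."""

    def __init__(self, logger: logging.Logger) -> None:
        self.logger = logger
        self.handlers = logger.handlers[:]
        self.level = logger.level
        self.propagate = logger.propagate

    def restore(self) -> None:
        self.logger.handlers[:] = self.handlers
        self.logger.setLevel(self.level)
        self.logger.propagate = self.propagate


@contextmanager
def move_handlers_to_root(task_logger: logging.Logger):
    """Hand the task logger's handlers to the root logger for the duration of the block."""
    root = logging.getLogger()
    saved_root = LoggerStateSnapshot(root)
    saved_task = LoggerStateSnapshot(task_logger)
    console = next((h for h in root.handlers if h.name == "console"), None)
    try:
        # Root takes over the task handlers; the task logger just propagates upward,
        # so nothing is logged twice.
        root.handlers[:] = task_logger.handlers
        root.setLevel(task_logger.level)
        task_logger.handlers[:] = []
        task_logger.propagate = True
        if IS_K8S_EXECUTOR_POD and console is not None and console not in root.handlers:
            # Keep stdout alive so the pod log (read by the webserver) still shows task output.
            root.addHandler(console)
        yield
    finally:
        saved_root.restore()
        saved_task.restore()


if __name__ == "__main__":
    root = logging.getLogger()
    console = logging.StreamHandler(sys.stdout)
    console.name = "console"
    root.addHandler(console)
    root.setLevel(logging.INFO)

    task_log = logging.getLogger("example.task")
    # Hypothetical stand-in for Airflow's FileTaskHandler.
    task_log.addHandler(logging.FileHandler("/tmp/example_task.log"))
    task_log.setLevel(logging.INFO)

    with move_handlers_to_root(task_log):
        # Reaches the file handler via root; also reaches stdout when
        # AIRFLOW_IS_K8S_EXECUTOR_POD is set.
        logging.getLogger("example.task.inner").info("hello from inside the task run")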