Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions airflow-core/src/airflow/executors/local_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class LocalExecutor(BaseExecutor):
"""

is_local: bool = True
is_mp_using_fork: bool = multiprocessing.get_start_method() == "fork"

serve_logs: bool = True

Expand All @@ -163,6 +164,11 @@ def start(self) -> None:
# (it looks like an int to python)
self._unread_messages = multiprocessing.Value(ctypes.c_uint)

if self.is_mp_using_fork:
# This creates the maximum number of worker processes (parallelism) at once
# to minimize gc freeze/unfreeze cycles when using fork in multiprocessing
self._spawn_workers_with_gc_freeze(self.parallelism)

def _check_workers(self):
# Reap any dead workers
to_remove = set()
Expand All @@ -186,9 +192,14 @@ def _check_workers(self):
# via `sync()` a few times before the spawned process actually starts picking up messages. Try not to
# create too much
if num_outstanding and len(self.workers) < self.parallelism:
# This only creates one worker, which is fine as we call this directly after putting a message on
# activity_queue in execute_async
self._spawn_worker()
if self.is_mp_using_fork:
# This creates the maximum number of worker processes at once
# to minimize gc freeze/unfreeze cycles when using fork in multiprocessing
self._spawn_workers_with_gc_freeze(self.parallelism - len(self.workers))
else:
# This only creates one worker, which is fine as we call this directly after putting a message on
# activity_queue in execute_async when using spawn in multiprocessing
self._spawn_worker()

def _spawn_worker(self):
p = multiprocessing.Process(
Expand All @@ -205,6 +216,26 @@ def _spawn_worker(self):
assert p.pid # Since we've called start
self.workers[p.pid] = p

def _spawn_workers_with_gc_freeze(self, spawn_number):
"""
Freeze the GC before forking worker process and unfreeze it after forking.

This is done to prevent memory increase due to COW (Copy-on-Write) by moving all
existing objects to the permanent generation before forking the process. After forking,
unfreeze is called to ensure there is no impact on gc operations
in the original running process.

Ref: https://docs.python.org/3/library/gc.html#gc.freeze
"""
import gc

gc.freeze()
try:
for _ in range(spawn_number):
self._spawn_worker()
finally:
gc.unfreeze()

def sync(self) -> None:
"""Sync will get called periodically by the heartbeat method."""
self._read_results()
Expand Down
27 changes: 24 additions & 3 deletions airflow-core/tests/unit/executors/test_local_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import gc
import multiprocessing
import os
from unittest import mock
Expand Down Expand Up @@ -44,6 +45,11 @@


class TestLocalExecutor:
"""
When the executor is started, end() must be called before the test finishes.
Otherwise, subprocesses will remain running, preventing the test from terminating and causing a timeout.
"""

TEST_SUCCESS_COMMANDS = 5

def test_supports_sentry(self):
Expand All @@ -55,6 +61,20 @@ def test_is_local_default_value(self):
def test_serve_logs_default_value(self):
assert LocalExecutor.serve_logs

# Only meaningful when multiprocessing uses "fork": start() then pre-forks the
# whole pool inside a gc.freeze()/gc.unfreeze() pair. Skipped under "spawn".
@skip_spawn_mp_start
# NOTE: mock.patch decorators apply bottom-up, so the innermost patch
# (gc.freeze) is injected as the first mock argument, gc.unfreeze as the second.
@mock.patch.object(gc, "unfreeze")
@mock.patch.object(gc, "freeze")
def test_executor_worker_spawned(self, mock_freeze, mock_unfreeze):
    executor = LocalExecutor(parallelism=5)
    executor.start()

    # The freeze/unfreeze pair must bracket the bulk spawn exactly once.
    mock_freeze.assert_called_once()
    mock_unfreeze.assert_called_once()

    # Under fork, start() eagerly creates the full `parallelism` worker pool.
    assert len(executor.workers) == 5

    # end() is required: otherwise the forked workers keep running and the
    # test run hangs until timeout (see the class docstring).
    executor.end()

@skip_spawn_mp_start
@mock.patch("airflow.sdk.execution_time.supervisor.supervise")
def test_execution(self, mock_supervise):
Expand Down Expand Up @@ -86,11 +106,12 @@ def fake_supervise(ti, **kwargs):
mock_supervise.side_effect = fake_supervise

executor = LocalExecutor(parallelism=2)
executor.start()

assert executor.result_queue.empty()

with spy_on(executor._spawn_worker) as spawn_worker:
executor.start()

assert executor.result_queue.empty()

for ti in success_tis:
executor.queue_workload(
workloads.ExecuteTask(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test_spawn_worker_when_needed(setup_executor):
executor.activity_queue.empty.return_value = False
executor.workers = {}
executor._check_workers()
executor._spawn_worker.assert_called_once()
executor._spawn_worker.assert_called()


def test_no_spawn_if_parallelism_reached(setup_executor):
Expand Down Expand Up @@ -133,4 +133,4 @@ def test_spawn_worker_when_we_have_parallelism_left(setup_executor):
executor.activity_queue.empty.return_value = False
executor._spawn_worker.reset_mock()
executor._check_workers()
executor._spawn_worker.assert_called_once()
executor._spawn_worker.assert_called()