4 changes: 3 additions & 1 deletion task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -685,7 +685,9 @@ def _service_subprocess(
:param expect_signal: Signal not to log if the task exits with this code.
:returns: The process exit code, or None if it's still alive
"""
events = self.selector.select(timeout=max_wait_time)
# Enforce a minimum timeout so select() never spins in a tight loop (CPU spike) when max_wait_time is 0 or negative
timeout = max(0.01, max_wait_time)
events = self.selector.select(timeout=timeout)
for key, _ in events:
# Retrieve the handler responsible for processing this file object (e.g., stdout, stderr)
socket_handler = key.data
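For context, here is a minimal standalone sketch (not the Airflow implementation) of why a zero or negative select() timeout busy-loops and how a small floor avoids it. The selector, socket pair, and service_once helper are illustrative assumptions; only the max(0.01, max_wait_time) clamp mirrors the change in the diff above.

import selectors
import socket
import time

# Illustrative setup: a selector watching one idle socket, standing in for the
# supervisor waiting on a quiet subprocess pipe.
sel = selectors.DefaultSelector()
reader, _writer = socket.socketpair()
sel.register(reader, selectors.EVENT_READ)

def service_once(max_wait_time: float) -> None:
    # With timeout <= 0, select() returns immediately when nothing is ready,
    # so a surrounding monitor loop spins at 100% CPU.
    # Clamping to a small positive floor makes each pass block briefly instead.
    timeout = max(0.01, max_wait_time)  # hypothetical mirror of the supervisor fix
    sel.select(timeout=timeout)

start = time.monotonic()
for _ in range(100):
    service_once(max_wait_time=-5)  # the kind of value that previously caused a tight loop
print(f"100 idle passes took {time.monotonic() - start:.2f}s instead of returning instantly")

With the clamp in place, the loop above takes roughly a second for 100 idle passes; without it, select() with a non-positive timeout returns immediately every time and the loop pegs a CPU core.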
67 changes: 67 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -958,6 +958,73 @@ def test_service_subprocess(self, watched_subprocess, mock_process, mocker):
# Validate that `_check_subprocess_exit` is called
mock_process.wait.assert_called_once_with(timeout=0)

def test_max_wait_time_prevents_cpu_spike(self, watched_subprocess, mock_process, monkeypatch):
"""Test that max_wait_time calculation prevents CPU spike when heartbeat timeout is reached."""
# Mock the configuration to reproduce the CPU spike scenario
# Set heartbeat timeout to be very small relative to MIN_HEARTBEAT_INTERVAL
monkeypatch.setattr("airflow.sdk.execution_time.supervisor.HEARTBEAT_TIMEOUT", 1)
monkeypatch.setattr("airflow.sdk.execution_time.supervisor.MIN_HEARTBEAT_INTERVAL", 10)

# Set up a scenario where the last successful heartbeat was a long time ago,
# which makes the heartbeat-based wait calculation go negative
watched_subprocess._last_successful_heartbeat = time.monotonic() - 100  # 100 seconds ago

# Mock process to still be alive (not exited)
mock_process.wait.side_effect = psutil.TimeoutExpired(pid=12345, seconds=0)

# Call _service_subprocess (invoked from _monitor_subprocess) directly
# to exercise the timeout clamping with a value below the 0.01 floor
watched_subprocess._service_subprocess(max_wait_time=0.005)

# Verify that selector.select was called with a timeout of at least 0.01;
# this confirms the fix prevents the timeout=0 scenario that causes a CPU spike
watched_subprocess.selector.select.assert_called_once()
call_args = watched_subprocess.selector.select.call_args
timeout_arg = call_args[1]["timeout"] if "timeout" in call_args[1] else call_args[0][0]

# The timeout should be at least 0.01 (our minimum), never 0
assert timeout_arg >= 0.01, f"Expected timeout >= 0.01, got {timeout_arg}"

@pytest.mark.parametrize(
["heartbeat_timeout", "min_interval", "heartbeat_ago", "expected_min_timeout"],
[
# Normal case: heartbeat is recent; timeout still respects the 0.01 floor
pytest.param(30, 5, 5, 0.01, id="normal_heartbeat"),
# Edge case: heartbeat timeout exceeded, should use minimum
pytest.param(10, 20, 50, 0.01, id="heartbeat_timeout_exceeded"),
# Bug reproduction case: timeout < interval, heartbeat very old
pytest.param(5, 10, 100, 0.01, id="cpu_spike_scenario"),
],
)
def test_max_wait_time_calculation_edge_cases(
self,
watched_subprocess,
mock_process,
monkeypatch,
heartbeat_timeout,
min_interval,
heartbeat_ago,
expected_min_timeout,
):
"""Test max_wait_time calculation in various edge case scenarios."""
monkeypatch.setattr("airflow.sdk.execution_time.supervisor.HEARTBEAT_TIMEOUT", heartbeat_timeout)
monkeypatch.setattr("airflow.sdk.execution_time.supervisor.MIN_HEARTBEAT_INTERVAL", min_interval)

watched_subprocess._last_successful_heartbeat = time.monotonic() - heartbeat_ago
mock_process.wait.side_effect = psutil.TimeoutExpired(pid=12345, seconds=0)

# Call the method and verify the timeout passed to select() is never below the 0.01 floor
# (the floor only raises small values, so a large value like 999 passes through unchanged)
watched_subprocess._service_subprocess(max_wait_time=999)

# Extract the timeout that was actually used
watched_subprocess.selector.select.assert_called_once()
call_args = watched_subprocess.selector.select.call_args
actual_timeout = call_args[1]["timeout"] if "timeout" in call_args[1] else call_args[0][0]

assert actual_timeout >= expected_min_timeout


class TestHandleRequest:
@pytest.fixture