31 changes: 31 additions & 0 deletions airflow-core/src/airflow/dag_processing/manager.py
@@ -49,6 +49,7 @@
from airflow._shared.timezones import timezone
from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI
from airflow.configuration import conf
from airflow.dag_processing.bundles.base import BundleUsageTrackingManager
from airflow.dag_processing.bundles.manager import DagBundlesManager
from airflow.dag_processing.collection import update_dag_parsing_results_in_db
from airflow.dag_processing.processor import DagFileParsingResult, DagFileProcessorProcess
@@ -226,6 +227,9 @@ class DagFileProcessorManager(LoggingMixin):
_force_refresh_bundles: set[str] = attrs.field(factory=set, init=False)
"""List of bundles that need to be force refreshed in the next loop"""

_stale_bundles_last_cleaned: float = attrs.field(default=0, init=False)
"""Last time we checked for stale bundle versions to clean up"""

_file_parsing_sort_mode: str = attrs.field(
factory=_config_get_factory("dag_processor", "file_parsing_sort_mode")
)
@@ -384,6 +388,7 @@ def _run_parsing_loop(self):
self._add_callback_to_queue(callback)
self._scan_stale_dags()
DagWarning.purge_inactive_dag_warnings()
self._cleanup_stale_bundle_versions()

# Update number of loop iteration.
self._num_run += 1
@@ -621,6 +626,32 @@ def _refresh_dag_bundles(self, known_files: dict[str, set[DagFileInfo]]):
self._resort_file_queue()
self._add_new_files_to_queue(known_files=known_files)

def _cleanup_stale_bundle_versions(self) -> None:
"""Clean up stale bundle versions."""
check_interval = conf.getint(
section="dag_processor",
key="stale_bundle_cleanup_interval",
)
if check_interval <= 0:
return

now_seconds = time.monotonic()
next_cleanup = self._stale_bundles_last_cleaned + check_interval
if now_seconds < next_cleanup:
self.log.debug(
"Not time to clean up stale bundle versions yet - skipping. Next cleanup in %.2f seconds",
next_cleanup - now_seconds,
)
return

self._stale_bundles_last_cleaned = now_seconds

self.log.info("Cleaning up stale bundle versions")
try:
BundleUsageTrackingManager().remove_stale_bundle_versions()
except Exception:
self.log.exception("Error cleaning up stale bundle versions")

def _find_files_in_bundle(self, bundle: BaseDagBundle) -> list[Path]:
"""Get relative paths for dag files from bundle dir."""
# Build up a list of Python files that could contain DAGs
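The throttling in _cleanup_stale_bundle_versions above boils down to the pattern sketched below. This is a standalone, illustrative sketch only (the class and names are not part of the PR): the dag_processor.stale_bundle_cleanup_interval option gates how often cleanup runs, and a value of 0 or less disables it entirely.

import time


class IntervalThrottle:
    """Illustrative helper: run an action at most once per interval, using a monotonic clock."""

    def __init__(self, interval_seconds: float) -> None:
        self.interval_seconds = interval_seconds
        self._last_run = 0.0

    def should_run(self) -> bool:
        if self.interval_seconds <= 0:
            return False  # a non-positive interval disables the action
        now = time.monotonic()
        if now < self._last_run + self.interval_seconds:
            return False  # not enough time has passed since the last run
        self._last_run = now
        return True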
70 changes: 70 additions & 0 deletions airflow-core/tests/unit/dag_processing/test_manager.py
@@ -1474,3 +1474,73 @@ def test_create_process_passes_bundle_name_to_process_start(
mock_process_start.assert_called_once()
call_kwargs = mock_process_start.call_args.kwargs
assert call_kwargs["bundle_name"] == "testing"

def test_cleanup_stale_bundle_versions_disabled_when_interval_zero(self):
"""Test that cleanup is skipped when stale_bundle_cleanup_interval is 0."""
with conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): "0"}):
manager = DagFileProcessorManager(max_runs=1)

with mock.patch(
"airflow.dag_processing.manager.BundleUsageTrackingManager"
) as mock_cleanup_manager:
manager._cleanup_stale_bundle_versions()

# Cleanup manager should not be instantiated when interval is 0
mock_cleanup_manager.assert_not_called()

def test_cleanup_stale_bundle_versions_runs_on_interval(self):
"""Test that cleanup runs based on the configured interval."""
with conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): "60"}):
manager = DagFileProcessorManager(max_runs=1)

with mock.patch(
"airflow.dag_processing.manager.BundleUsageTrackingManager"
) as mock_cleanup_manager_class:
mock_cleanup_manager = mock_cleanup_manager_class.return_value

# First call should run cleanup
manager._cleanup_stale_bundle_versions()
mock_cleanup_manager.remove_stale_bundle_versions.assert_called_once()

# Immediate second call should skip (not enough time passed)
mock_cleanup_manager.remove_stale_bundle_versions.reset_mock()
manager._cleanup_stale_bundle_versions()
mock_cleanup_manager.remove_stale_bundle_versions.assert_not_called()

def test_cleanup_stale_bundle_versions_runs_after_interval(self):
"""Test that cleanup runs again after the interval has passed."""
with conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): "60"}):
manager = DagFileProcessorManager(max_runs=1)

with mock.patch(
"airflow.dag_processing.manager.BundleUsageTrackingManager"
) as mock_cleanup_manager_class:
mock_cleanup_manager = mock_cleanup_manager_class.return_value

# First call should run cleanup
manager._cleanup_stale_bundle_versions()
assert mock_cleanup_manager.remove_stale_bundle_versions.call_count == 1

# Simulate time passing
manager._stale_bundles_last_cleaned -= 61 # Go back 61 seconds

# Now cleanup should run again
manager._cleanup_stale_bundle_versions()
assert mock_cleanup_manager.remove_stale_bundle_versions.call_count == 2

def test_cleanup_stale_bundle_versions_handles_exceptions(self):
"""Test that exceptions during cleanup are logged but don't crash the manager."""
with conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): "60"}):
manager = DagFileProcessorManager(max_runs=1)

with mock.patch(
"airflow.dag_processing.manager.BundleUsageTrackingManager"
) as mock_cleanup_manager_class:
mock_cleanup_manager = mock_cleanup_manager_class.return_value
mock_cleanup_manager.remove_stale_bundle_versions.side_effect = Exception("Test error")

# Should not raise an exception
manager._cleanup_stale_bundle_versions()

# Cleanup was attempted
mock_cleanup_manager.remove_stale_bundle_versions.assert_called_once()
63 changes: 61 additions & 2 deletions providers/git/src/airflow/providers/git/bundles/git.py
Review comment (Member):
This isn't a GitDagBundle problem. You'd have to do this for every bundle that supports versioning!

Instead, the dag processor should init a bundle at the right version if a callback ends up coming in. Just like workers do when starting a task.
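A minimal sketch of the approach the reviewer describes, not part of this PR: rather than each versioned bundle pre-creating copies, the processor would materialize the bundle at the callback's recorded version on demand. The helper name is hypothetical, and the sketch assumes DagBundlesManager.get_bundle accepts a version and that the callback carries a bundle name and version, mirroring how workers prepare the right bundle before running a task.

from airflow.dag_processing.bundles.manager import DagBundlesManager


def _bundle_for_callback(bundle_name: str, bundle_version: str | None):
    # Hypothetical helper: fetch the bundle pinned to the version the callback
    # was created against; a None version falls back to the latest checkout.
    bundle = DagBundlesManager().get_bundle(name=bundle_name, version=bundle_version)
    bundle.initialize()  # versioned bundles check out the requested version here
    return bundle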

@@ -18,16 +18,18 @@

import os
import shutil
import tempfile
from contextlib import nullcontext
from pathlib import Path
from urllib.parse import urlparse

import pendulum
import structlog
from git import Repo
from git.exc import BadName, GitCommandError, InvalidGitRepositoryError, NoSuchPathError
from tenacity import retry, retry_if_exception_type, stop_after_attempt

from airflow.dag_processing.bundles.base import BaseDagBundle
from airflow.dag_processing.bundles.base import BaseDagBundle, get_bundle_tracking_file
from airflow.providers.common.compat.sdk import AirflowException
from airflow.providers.git.hooks.git import GitHook

@@ -186,6 +188,60 @@ def _clone_repo_if_required(self) -> None:
shutil.rmtree(self.repo_path)
raise

def _create_versioned_copy(self) -> None:
"""
Create a versioned copy of the repo after tracking repo is set up.

This ensures the versioned repo exists for DAG callbacks that need to run
with a specific bundle version. It also updates the tracking file so that
stale bundle cleanup (via BundleUsageTrackingManager) knows this version is in use.
"""
current_version = self.repo.head.commit.hexsha
version_path = self.versions_dir / current_version

# Always update the tracking file to mark this version as recently used
self._update_version_tracking_file(current_version)

if os.path.exists(version_path):
self._log.debug("versioned repo already exists", version_path=version_path)
return

self.versions_dir.mkdir(parents=True, exist_ok=True)
self._log.info(
"Creating versioned copy of repository",
version=current_version,
version_path=version_path,
)

# Clone from bare repo to versioned path
versioned_repo = Repo.clone_from(
url=self.bare_repo_path,
to_path=version_path,
)
versioned_repo.head.set_reference(str(versioned_repo.commit(current_version)))
versioned_repo.head.reset(index=True, working_tree=True)

if self.prune_dotgit_folder:
shutil.rmtree(version_path / ".git")

versioned_repo.close()

def _update_version_tracking_file(self, version: str) -> None:
"""
Update the tracking file for a bundle version.

This is used by BundleUsageTrackingManager to determine which versions
are still in use and which can be cleaned up.
"""
tracking_file_path = get_bundle_tracking_file(bundle_name=self.name, version=version)
tracking_file_path.parent.mkdir(parents=True, exist_ok=True)

with tempfile.TemporaryDirectory() as td:
temp_file = Path(td) / tracking_file_path.name
now = pendulum.now(tz=pendulum.UTC)
temp_file.write_text(now.isoformat())
os.replace(temp_file, tracking_file_path)

@retry(
retry=retry_if_exception_type((InvalidGitRepositoryError, GitCommandError)),
stop=stop_after_attempt(2),
@@ -298,7 +354,10 @@ def refresh(self) -> None:
except GitCommandError as e:
raise RuntimeError("Error pulling submodule from repository") from e

self.repo.close()
# Create a versioned copy of the repo for DAG callbacks
self._create_versioned_copy()

self.repo.close()

@staticmethod
def _convert_git_ssh_url_to_https(url: str) -> str:
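The tracking file written by _update_version_tracking_file holds a single ISO-8601 UTC timestamp. The internals of BundleUsageTrackingManager.remove_stale_bundle_versions are not shown in this diff, but as an illustration only, a cleanup pass could read that timestamp back and compare it against a staleness threshold along these lines (the helper and the 7-day threshold are assumptions for the example):

from datetime import datetime, timedelta, timezone
from pathlib import Path


def is_stale(tracking_file: Path, max_age: timedelta = timedelta(days=7)) -> bool:
    # The tracking file contains the isoformat() timestamp of the last time this
    # bundle version was used; anything older than max_age is a cleanup candidate.
    last_used = datetime.fromisoformat(tracking_file.read_text().strip())
    return datetime.now(timezone.utc) - last_used > max_age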
108 changes: 105 additions & 3 deletions providers/git/tests/unit/git/bundles/test_git.py
@@ -398,9 +398,10 @@ def test_raises_when_no_repo_url(self):
with pytest.raises(AirflowException, match=f"Connection {CONN_NO_REPO_URL} doesn't have a host url"):
bundle.initialize()

@mock.patch("airflow.providers.git.bundles.git.GitDagBundle._create_versioned_copy")
@mock.patch("airflow.providers.git.bundles.git.GitHook")
@mock.patch("airflow.providers.git.bundles.git.Repo")
def test_with_path_as_repo_url(self, mock_gitRepo, mock_githook):
def test_with_path_as_repo_url(self, mock_gitRepo, mock_githook, mock_create_versioned_copy):
bundle = GitDagBundle(
name="test",
git_conn_id=CONN_ONLY_PATH,
@@ -410,8 +411,9 @@ def test_with_path_as_repo_url(self, mock_gitRepo, mock_githook):
assert mock_gitRepo.clone_from.call_count == 2
assert mock_gitRepo.return_value.git.checkout.call_count == 1

@mock.patch("airflow.providers.git.bundles.git.GitDagBundle._create_versioned_copy")
@mock.patch("airflow.providers.git.bundles.git.Repo")
def test_refresh_with_git_connection(self, mock_gitRepo):
def test_refresh_with_git_connection(self, mock_gitRepo, mock_create_versioned_copy):
bundle = GitDagBundle(
name="test",
git_conn_id="git_default",
@@ -894,12 +896,13 @@ def test_initialize_fetches_submodules_when_enabled(
)
mock_rmtree.assert_not_called()

@mock.patch("airflow.providers.git.bundles.git.GitDagBundle._create_versioned_copy")
@mock.patch("airflow.providers.git.bundles.git.shutil.rmtree")
@mock.patch("airflow.providers.git.bundles.git.os.path.exists")
@mock.patch("airflow.providers.git.bundles.git.GitHook")
@mock.patch("airflow.providers.git.bundles.git.Repo")
def test_refresh_fetches_submodules_when_enabled(
self, mock_repo_class, mock_githook, mock_exists, mock_rmtree
self, mock_repo_class, mock_githook, mock_exists, mock_rmtree, mock_create_versioned_copy
):
"""Test that submodules are synced and updated when submodules=True during refresh."""
mock_githook.return_value.repo_url = "git@github.com:apache/airflow.git"
@@ -976,3 +979,102 @@ def test_submodule_fetch_error_raises_runtime_error(
bundle.initialize()

mock_rmtree.assert_not_called()

@mock.patch("airflow.providers.git.bundles.git.GitHook")
def test_refresh_creates_versioned_copy(self, mock_githook, git_repo):
"""Test that refresh creates a versioned copy of the repo."""
repo_path, repo = git_repo
mock_githook.return_value.repo_url = repo_path

bundle = GitDagBundle(name="test", git_conn_id=CONN_HTTPS, tracking_ref=GIT_DEFAULT_BRANCH)
bundle.initialize()

current_version = bundle.get_current_version()

# Check that versioned copy exists
version_path = bundle.versions_dir / current_version
assert version_path.exists(), "Versioned copy should be created after initialize/refresh"

# Verify the versioned copy has the same files
files_in_versioned = {f.name for f in version_path.iterdir() if f.is_file()}
assert "test_dag.py" in files_in_versioned

assert_repo_is_closed(bundle)

@mock.patch("airflow.providers.git.bundles.git.GitHook")
def test_refresh_creates_new_versioned_copy_on_version_change(self, mock_githook, git_repo):
"""Test that refresh creates a new versioned copy when the version changes."""
repo_path, repo = git_repo
mock_githook.return_value.repo_url = repo_path

with repo.config_writer() as writer:
writer.set_value("user", "name", "Test User")
writer.set_value("user", "email", "test@example.com")

bundle = GitDagBundle(name="test", git_conn_id=CONN_HTTPS, tracking_ref=GIT_DEFAULT_BRANCH)
bundle.initialize()

initial_version = bundle.get_current_version()

# Add a new commit to the repo
file_path = repo_path / "new_file.py"
with open(file_path, "w") as f:
f.write("new content")
repo.index.add([file_path])
repo.index.commit("New commit")

bundle.refresh()

new_version = bundle.get_current_version()
assert new_version != initial_version, "Version should change after refresh"

# Check that both versioned copies exist
assert (bundle.versions_dir / initial_version).exists(), "Initial versioned copy should exist"
assert (bundle.versions_dir / new_version).exists(), "New versioned copy should exist"

# Verify new versioned copy has the new file
new_version_files = {f.name for f in (bundle.versions_dir / new_version).iterdir() if f.is_file()}
assert "new_file.py" in new_version_files

assert_repo_is_closed(bundle)

@mock.patch("airflow.providers.git.bundles.git.GitHook")
def test_versioned_copy_respects_prune_dotgit_folder(self, mock_githook, git_repo):
"""Test that versioned copy respects prune_dotgit_folder setting."""
repo_path, repo = git_repo
mock_githook.return_value.repo_url = repo_path

bundle = GitDagBundle(
name="test_prune",
git_conn_id=CONN_HTTPS,
tracking_ref=GIT_DEFAULT_BRANCH,
prune_dotgit_folder=True,
)
bundle.initialize()
version_path = bundle.versions_dir / bundle.get_current_version()
assert not (version_path / ".git").exists(), ".git should be pruned from versioned copy"

assert_repo_is_closed(bundle)

@mock.patch("airflow.providers.git.bundles.git.GitHook")
def test_versioned_copy_skipped_if_exists(self, mock_githook, git_repo):
"""Test that versioned copy is not recreated if it already exists."""
repo_path, repo = git_repo
mock_githook.return_value.repo_url = repo_path

bundle = GitDagBundle(name="test", git_conn_id=CONN_HTTPS, tracking_ref=GIT_DEFAULT_BRANCH)
bundle.initialize()

version = bundle.get_current_version()
version_path = bundle.versions_dir / version

# Get the modification time of the versioned copy
initial_mtime = version_path.stat().st_mtime

# Refresh again (same version)
bundle.refresh()

# The versioned copy should not have been modified
assert version_path.stat().st_mtime == initial_mtime, "Versioned copy should not be recreated"

assert_repo_is_closed(bundle)