Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions providers/git/src/airflow/providers/git/bundles/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import os
from contextlib import nullcontext
from pathlib import Path
from urllib.parse import urlparse

Expand Down Expand Up @@ -53,7 +54,7 @@ def __init__(
*,
tracking_ref: str,
subdir: str | None = None,
git_conn_id: str = "git_default",
git_conn_id: str | None = None,
repo_url: str | None = None,
**kwargs,
) -> None:
Expand All @@ -79,16 +80,20 @@ def __init__(
)

self._log.debug("bundle configured")
try:
self.hook = GitHook(git_conn_id=self.git_conn_id, repo_url=self.repo_url)
self.repo_url = self.hook.repo_url
self._log.debug("repo_url updated from hook", repo_url=self.repo_url)
except AirflowException as e:
self._log.warning("Could not create GitHook", conn_id=self.git_conn_id, exc=e)
self.hook: GitHook | None = None
if self.git_conn_id:
try:
self.hook = GitHook(git_conn_id=self.git_conn_id, repo_url=self.repo_url)
except AirflowException as e:
self._log.warning("Could not create GitHook", conn_id=self.git_conn_id, exc=e)
else:
self.repo_url = self.hook.repo_url
self._log.debug("repo_url updated from hook", repo_url=self.repo_url)

def _initialize(self):
with self.lock():
with self.hook.configure_hook_env():
cm = self.hook.configure_hook_env() if self.hook else nullcontext()
with cm:
self._clone_bare_repo_if_required()
self._ensure_version_in_bare_repo()

Expand Down Expand Up @@ -134,7 +139,7 @@ def _clone_bare_repo_if_required(self) -> None:
url=self.repo_url,
to_path=self.bare_repo_path,
bare=True,
env=self.hook.env,
env=self.hook.env if self.hook else None,
)
except GitCommandError as e:
raise AirflowException("Error cloning repository") from e
Expand Down Expand Up @@ -177,18 +182,19 @@ def _has_version(repo: Repo, version: str) -> bool:

def _fetch_bare_repo(self):
refspecs = ["+refs/heads/*:refs/heads/*", "+refs/tags/*:refs/tags/*"]
if self.hook.env:
with self.bare_repo.git.custom_environment(GIT_SSH_COMMAND=self.hook.env.get("GIT_SSH_COMMAND")):
self.bare_repo.remotes.origin.fetch(refspecs)
else:
cm = nullcontext()
if self.hook and (cmd := self.hook.env.get("GIT_SSH_COMMAND")):
cm = self.bare_repo.git.custom_environment(GIT_SSH_COMMAND=cmd)
with cm:
self.bare_repo.remotes.origin.fetch(refspecs)

def refresh(self) -> None:
if self.version:
raise AirflowException("Refreshing a specific version is not supported")

with self.lock():
with self.hook.configure_hook_env():
cm = self.hook.configure_hook_env() if self.hook else nullcontext()
with cm:
self._fetch_bare_repo()
self.repo.remotes.origin.fetch(
["+refs/heads/*:refs/remotes/origin/*", "+refs/tags/*:refs/tags/*"]
Expand Down
60 changes: 45 additions & 15 deletions providers/git/tests/unit/git/bundles/test_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

import json
import os
import re
from unittest import mock
Expand Down Expand Up @@ -118,14 +119,14 @@ def test_https_access_token_repo_url_overrides_connection_host_when_provided(sel
assert bundle.repo_url == f"https://{ACCESS_TOKEN}@github.com/apache/zzzairflow"

def test_falls_back_to_connection_host_when_no_repo_url_provided(self):
bundle = GitDagBundle(name="test", tracking_ref=GIT_DEFAULT_BRANCH)
bundle = GitDagBundle(name="test", git_conn_id=CONN_HTTPS, tracking_ref=GIT_DEFAULT_BRANCH)
assert bundle.repo_url == bundle.hook.repo_url

@mock.patch("airflow.providers.git.bundles.git.GitHook")
def test_get_current_version(self, mock_githook, git_repo):
repo_path, repo = git_repo
mock_githook.return_value.repo_url = repo_path
bundle = GitDagBundle(name="test", tracking_ref=GIT_DEFAULT_BRANCH)
bundle = GitDagBundle(name="test", git_conn_id=CONN_HTTPS, tracking_ref=GIT_DEFAULT_BRANCH)

bundle.initialize()

Expand All @@ -146,6 +147,7 @@ def test_get_specific_version(self, mock_githook, git_repo):

bundle = GitDagBundle(
name="test",
git_conn_id=CONN_HTTPS,
version=starting_commit.hexsha,
tracking_ref=GIT_DEFAULT_BRANCH,
)
Expand Down Expand Up @@ -174,6 +176,7 @@ def test_get_tag_version(self, mock_githook, git_repo):

bundle = GitDagBundle(
name="test",
git_conn_id=CONN_HTTPS,
version="test",
tracking_ref=GIT_DEFAULT_BRANCH,
)
Expand All @@ -195,7 +198,7 @@ def test_get_latest(self, mock_githook, git_repo):
repo.index.add([file_path])
repo.index.commit("Another commit")

bundle = GitDagBundle(name="test", tracking_ref=GIT_DEFAULT_BRANCH)
bundle = GitDagBundle(name="test", git_conn_id=CONN_HTTPS, tracking_ref=GIT_DEFAULT_BRANCH)
bundle.initialize()

assert bundle.get_current_version() != starting_commit.hexsha
Expand All @@ -221,7 +224,7 @@ def test_refresh(self, mock_githook, git_repo, amend):
writer.set_value("user", "name", "Test User")
writer.set_value("user", "email", "test@example.com")

bundle = GitDagBundle(name="test", tracking_ref=GIT_DEFAULT_BRANCH)
bundle = GitDagBundle(name="test", git_conn_id=CONN_HTTPS, tracking_ref=GIT_DEFAULT_BRANCH)
bundle.initialize()

assert bundle.get_current_version() == starting_commit.hexsha
Expand Down Expand Up @@ -252,7 +255,7 @@ def test_refresh_tag(self, mock_githook, git_repo):
# add tag
repo.create_tag("test123")

bundle = GitDagBundle(name="test", tracking_ref="test123")
bundle = GitDagBundle(name="test", git_conn_id=CONN_HTTPS, tracking_ref="test123")
bundle.initialize()
assert bundle.get_current_version() == starting_commit.hexsha

Expand All @@ -279,7 +282,7 @@ def test_head(self, mock_githook, git_repo):
mock_githook.return_value.repo_url = repo_path

repo.create_head("test")
bundle = GitDagBundle(name="test", tracking_ref="test")
bundle = GitDagBundle(name="test", git_conn_id=CONN_HTTPS, tracking_ref="test")
bundle.initialize()
assert bundle.repo.head.ref.name == "test"

Expand All @@ -289,6 +292,7 @@ def test_version_not_found(self, mock_githook, git_repo):
mock_githook.return_value.repo_url = repo_path
bundle = GitDagBundle(
name="test",
git_conn_id=CONN_HTTPS,
version="not_found",
tracking_ref=GIT_DEFAULT_BRANCH,
)
Expand All @@ -313,6 +317,7 @@ def test_subdir(self, mock_githook, git_repo):

bundle = GitDagBundle(
name="test",
git_conn_id=CONN_HTTPS,
tracking_ref=GIT_DEFAULT_BRANCH,
subdir=subdir,
)
Expand Down Expand Up @@ -408,7 +413,7 @@ def test_refresh_with_git_connection(self, mock_gitRepo):
def test_view_url(self, mock_gitrepo, repo_url, extra_conn_kwargs, expected_url, session):
session.query(Connection).delete()
conn = Connection(
conn_id="git_default",
conn_id="my_git_connection",
host=repo_url,
conn_type="git",
**(extra_conn_kwargs or {}),
Expand All @@ -417,6 +422,7 @@ def test_view_url(self, mock_gitrepo, repo_url, extra_conn_kwargs, expected_url,
session.commit()
bundle = GitDagBundle(
name="test",
git_conn_id="my_git_connection",
tracking_ref="main",
)
bundle.initialize = mock.MagicMock()
Expand Down Expand Up @@ -505,6 +511,7 @@ def test_view_url_subdir(self, mock_gitrepo, repo_url, extra_conn_kwargs, expect
name="test",
tracking_ref="main",
subdir="subdir",
git_conn_id="git_default",
)
bundle.initialize = mock.MagicMock()
view_url = bundle.view_url("0f0f0f")
Expand All @@ -527,7 +534,7 @@ def test_clone_bare_repo_git_command_error(self, mock_githook):

with mock.patch("airflow.providers.git.bundles.git.Repo.clone_from") as mock_clone:
mock_clone.side_effect = GitCommandError("clone", "Simulated error")
bundle = GitDagBundle(name="test", tracking_ref="main")
bundle = GitDagBundle(name="test", git_conn_id=CONN_HTTPS, tracking_ref="main")
with pytest.raises(
AirflowException,
match=re.escape("Error cloning repository"),
Expand All @@ -548,25 +555,48 @@ def test_clone_repo_no_such_path_error(self, mock_githook):
assert "Repository path: %s not found" in str(exc_info.value)

@patch.dict(os.environ, {"AIRFLOW_CONN_MY_TEST_GIT": '{"host": "something"}'})
@pytest.mark.parametrize("conn_id, should_find", [("my_test_git", True), ("something-else", False)])
def test_repo_url_access_missing_connection_doesnt_error(self, conn_id, should_find):
@pytest.mark.parametrize(
"conn_id, expected_hook_type",
[("my_test_git", GitHook), ("something-else", type(None))],
)
def test_repo_url_access_missing_connection_doesnt_error(self, conn_id, expected_hook_type):
bundle = GitDagBundle(
name="testa",
tracking_ref="main",
git_conn_id=conn_id,
repo_url="some_repo_url",
)
assert bundle.repo_url == "some_repo_url"
if should_find:
assert isinstance(bundle.hook, GitHook)
else:
assert not hasattr(bundle, "hook")
assert isinstance(bundle.hook, expected_hook_type)

@mock.patch("airflow.providers.git.bundles.git.GitHook")
def test_lock_used(self, mock_githook, git_repo):
repo_path, repo = git_repo
mock_githook.return_value.repo_url = repo_path
bundle = GitDagBundle(name="test", tracking_ref=GIT_DEFAULT_BRANCH)
bundle = GitDagBundle(name="test", git_conn_id=CONN_HTTPS, tracking_ref=GIT_DEFAULT_BRANCH)
with mock.patch("airflow.providers.git.bundles.git.GitDagBundle.lock") as mock_lock:
bundle.initialize()
assert mock_lock.call_count == 2 # both initialize and refresh

@pytest.mark.parametrize(
"conn_json, repo_url, expected",
[
(
{"host": "git@github.com:apache/airflow.git"},
"git@github.com:apache/hello.git",
"git@github.com:apache/hello.git",
),
({"host": "git@github.com:apache/airflow.git"}, None, "git@github.com:apache/airflow.git"),
({}, "git@github.com:apache/hello.git", "git@github.com:apache/hello.git"),
],
)
def test_repo_url_precedence(self, conn_json, repo_url, expected):
conn_str = json.dumps(conn_json)
with patch.dict(os.environ, {"AIRFLOW_CONN_MY_TEST_GIT": conn_str}):
bundle = GitDagBundle(
name="test",
tracking_ref="main",
git_conn_id="my_test_git",
repo_url=repo_url,
)
assert bundle.repo_url == expected