Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
from __future__ import annotations

import logging
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException, Path, status

from airflow.api_fastapi.execution_api.datamodels.connection import ConnectionResponse
from airflow.api_fastapi.execution_api.deps import JWTBearerDep
from airflow.api_fastapi.execution_api.deps import JWTBearerDep, get_team_name_dep
from airflow.exceptions import AirflowNotFoundException
from airflow.models.connection import Connection

Expand Down Expand Up @@ -57,10 +58,12 @@ async def has_connection_access(
status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the connection"},
},
)
def get_connection(connection_id: str) -> ConnectionResponse:
def get_connection(
connection_id: str, team_name: Annotated[str | None, Depends(get_team_name_dep)]
) -> ConnectionResponse:
"""Get an Airflow connection."""
try:
connection = Connection.get_connection_from_secrets(connection_id)
connection = Connection.get_connection_from_secrets(connection_id, team_name=team_name)
except AirflowNotFoundException:
raise HTTPException(
status.HTTP_404_NOT_FOUND,
Expand Down
28 changes: 18 additions & 10 deletions airflow-core/src/airflow/models/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

from airflow._shared.module_loading import import_string
from airflow._shared.secrets_masker import mask_secret
from airflow.configuration import ensure_secrets_loaded
from airflow.configuration import conf, ensure_secrets_loaded
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.models.base import ID_LEN, Base
from airflow.models.crypto import get_fernet
Expand Down Expand Up @@ -490,13 +490,14 @@ def extra_dejson(self) -> dict:
return self.get_extra_dejson()

@classmethod
def get_connection_from_secrets(cls, conn_id: str) -> Connection:
def get_connection_from_secrets(cls, conn_id: str, team_name: str | None = None) -> Connection:
"""
Get connection by conn_id.

If `MetastoreBackend` is getting used in the execution context, use Task SDK API.

:param conn_id: connection id
:param team_name: Team name associated to the task trying to access the connection (if any)
:return: connection
"""
# TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
Expand Down Expand Up @@ -528,20 +529,27 @@ def get_connection_from_secrets(cls, conn_id: str) -> Connection:
raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined") from None
raise

# check cache first
# enabled only if SecretCache.init() has been called first
if team_name and not conf.getboolean("core", "multi_team"):
raise ValueError(
"Multi-team mode is not configured in the Airflow environment but the task trying to access the connection belongs to a team"
)

from airflow.sdk import SecretCache

try:
uri = SecretCache.get_connection_uri(conn_id)
return Connection(conn_id=conn_id, uri=uri)
except SecretCache.NotPresentException:
pass # continue business
# Disable cache if the variable belongs to a team. We might enable it later
if not team_name:
# check cache first
# enabled only if SecretCache.init() has been called first
try:
uri = SecretCache.get_connection_uri(conn_id)
return Connection(conn_id=conn_id, uri=uri)
except SecretCache.NotPresentException:
pass # continue business

# iterate over backends if not in cache (or expired)
for secrets_backend in ensure_secrets_loaded():
try:
conn = secrets_backend.get_connection(conn_id=conn_id)
conn = secrets_backend.get_connection(conn_id=conn_id, team_name=team_name)
if conn:
SecretCache.save_connection_uri(conn_id, conn.get_uri())
return conn
Expand Down
8 changes: 7 additions & 1 deletion airflow-core/src/airflow/secrets/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@
class EnvironmentVariablesBackend(BaseSecretsBackend):
"""Retrieves Connection object and Variable from environment variable."""

def get_conn_value(self, conn_id: str) -> str | None:
def get_conn_value(self, conn_id: str, team_name: str | None = None) -> str | None:
if team_name and (
team_var := os.environ.get(f"{CONN_ENV_PREFIX}_{team_name.upper()}___" + conn_id.upper())
):
# Format to set a team specific connection: AIRFLOW_CONN__<TEAM_ID>___<CONN_ID>
return team_var

return os.environ.get(CONN_ENV_PREFIX + conn_id.upper())

def get_variable(self, key: str, team_name: str | None = None) -> str | None:
Expand Down
2 changes: 1 addition & 1 deletion airflow-core/src/airflow/secrets/local_filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def _local_configs(self) -> dict[str, str]:
return {}
return load_configs_dict(self.configs_file)

def get_connection(self, conn_id: str) -> Connection | None:
def get_connection(self, conn_id: str, team_name: str | None = None) -> Connection | None:
if conn_id in self._local_connections:
return self._local_connections[conn_id]
return None
Expand Down
14 changes: 12 additions & 2 deletions airflow-core/src/airflow/secrets/metastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,27 @@ class MetastoreBackend(BaseSecretsBackend):
"""Retrieves Connection object and Variable from airflow metastore database."""

@provide_session
def get_connection(self, conn_id: str, session: Session = NEW_SESSION) -> Connection | None:
def get_connection(
self, conn_id: str, team_name: str | None = None, session: Session = NEW_SESSION
) -> Connection | None:
"""
Get Airflow Connection from Metadata DB.

:param conn_id: Connection ID
:param team_name: Team name associated to the task trying to access the connection (if any)
:param session: SQLAlchemy Session
:return: Connection Object
"""
from airflow.models import Connection

conn = session.scalar(select(Connection).where(Connection.conn_id == conn_id).limit(1))
conn = session.scalar(
select(Connection)
.where(
Connection.conn_id == conn_id,
or_(Connection.team_name == team_name, Connection.team_name.is_(None)),
)
.limit(1)
)
session.expunge_all()
return conn

Expand Down
8 changes: 4 additions & 4 deletions airflow-core/tests/unit/always/test_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ def setup_method(self) -> None:
def test_get_connection_second_try(self, mock_env_get, mock_meta_get):
mock_env_get.side_effect = [None] # return None
Connection.get_connection_from_secrets("fake_conn_id")
mock_meta_get.assert_called_once_with(conn_id="fake_conn_id")
mock_env_get.assert_called_once_with(conn_id="fake_conn_id")
mock_meta_get.assert_called_once_with(conn_id="fake_conn_id", team_name=None)
mock_env_get.assert_called_once_with(conn_id="fake_conn_id", team_name=None)

@mock.patch("airflow.secrets.metastore.MetastoreBackend.get_connection")
@mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_connection")
def test_get_connection_first_try(self, mock_env_get, mock_meta_get):
mock_env_get.return_value = Connection("something") # returns something
Connection.get_connection_from_secrets("fake_conn_id")
mock_env_get.assert_called_once_with(conn_id="fake_conn_id")
mock_env_get.assert_called_once_with(conn_id="fake_conn_id", team_name=None)
mock_meta_get.assert_not_called()

@conf_vars(
Expand Down Expand Up @@ -115,7 +115,7 @@ def test_backend_fallback_to_env_var(self, mock_get_connection):
conn = Connection.get_connection_from_secrets(conn_id="test_mysql")

# Assert that SystemsManagerParameterStoreBackend.get_conn_uri was called
mock_get_connection.assert_called_once_with(conn_id="test_mysql")
mock_get_connection.assert_called_once_with(conn_id="test_mysql", team_name=None)

assert conn.get_uri() == "mysql://airflow:airflow@host:5432/airflow"

Expand Down
34 changes: 32 additions & 2 deletions airflow-core/tests/unit/models/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
from airflow.sdk.execution_time.comms import ErrorResponse

from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.db import clear_db_connections

if TYPE_CHECKING:
Expand Down Expand Up @@ -409,8 +410,37 @@ def test_get_connection_from_secrets_metastore_backend(
mock_task_sdk_connection.get.assert_not_called()

# Verify the backends were called
mock_env_backend.assert_called_once_with(conn_id="test_conn")
mock_db_backend.assert_called_once_with(conn_id="test_conn")
mock_env_backend.assert_called_once_with(conn_id="test_conn", team_name=None)
mock_db_backend.assert_called_once_with(conn_id="test_conn", team_name=None)

@pytest.mark.db_test
@conf_vars({("core", "multi_team"): "True"})
@mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": None})
@mock.patch("airflow.sdk.Connection")
@mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_connection")
@mock.patch("airflow.secrets.metastore.MetastoreBackend.get_connection")
def test_get_connection_from_secrets_metastore_backend_with_team(
self, mock_db_backend, mock_env_backend, mock_task_sdk_connection, testing_team
):
"""Test the get_connection_from_secrets should call all the backends."""

mock_env_backend.return_value = None
mock_db_backend.return_value = Connection(conn_id="test_conn", conn_type="test", password="pass")

# Mock TaskSDK Connection to verify it is never imported
mock_task_sdk_connection.get.side_effect = Exception("TaskSDKConnection should not be used")

result = Connection.get_connection_from_secrets("test_conn", team_name=testing_team.name)

expected_connection = Connection(conn_id="test_conn", conn_type="test", password="pass")

# Verify the result is from MetastoreBackend
assert result.conn_id == expected_connection.conn_id
assert result.conn_type == expected_connection.conn_type

# Verify the backends were called
mock_env_backend.assert_called_once_with(conn_id="test_conn", team_name=testing_team.name)
mock_db_backend.assert_called_once_with(conn_id="test_conn", team_name=testing_team.name)

@pytest.mark.db_test
def test_get_team_name(self, testing_team: Team, session: Session):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,12 @@ def _standardize_secret_keys(self, secret: dict[str, Any]) -> dict[str, Any]:

return conn_d

def get_conn_value(self, conn_id: str) -> str | None:
def get_conn_value(self, conn_id: str, team_name: str | None = None) -> str | None:
"""
Get serialized representation of Connection.

:param conn_id: connection id
:param team_name: Team name associated to the task trying to access the connection (if any)
"""
if self.connections_prefix is None:
return None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,12 @@ def client(self):
session = SessionFactory(conn=conn_config).create_session()
return session.client(service_name="ssm", **client_kwargs)

def get_conn_value(self, conn_id: str) -> str | None:
def get_conn_value(self, conn_id: str, team_name: str | None = None) -> str | None:
"""
Get param value.

:param conn_id: connection id
:param team_name: Team name associated to the task trying to access the connection (if any)
"""
if self.connections_prefix is None:
return None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,12 @@ def _is_valid_prefix_and_sep(self) -> bool:
prefix = self.connections_prefix + self.sep
return _SecretManagerClient.is_valid_secret_name(prefix)

def get_conn_value(self, conn_id: str) -> str | None:
def get_conn_value(self, conn_id: str, team_name: str | None = None) -> str | None:
"""
Get serialized representation of Connection.

:param conn_id: connection id
:param team_name: Team name associated to the task trying to access the connection (if any)
"""
if self.connections_prefix is None:
return None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def get_response(self, conn_id: str) -> dict | None:
if TYPE_CHECKING:
from airflow.models.connection import Connection

def get_connection(self, conn_id: str) -> Connection | None:
def get_connection(self, conn_id: str, team_name: str | None = None) -> Connection | None:
"""
Get connection from Vault as secret.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,12 @@ def client(self) -> SecretClient:
client = SecretClient(vault_url=self.vault_url, credential=credential, **self.kwargs)
return client

def get_conn_value(self, conn_id: str) -> str | None:
def get_conn_value(self, conn_id: str, team_name: str | None = None) -> str | None:
"""
Get a serialized representation of Airflow Connection from an Azure Key Vault secret.

:param conn_id: The Airflow connection id to retrieve
:param team_name: Team name associated to the task trying to access the connection (if any)
"""
if self.connections_prefix is None:
return None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,12 @@ def __init__(
self.sep = sep
self.endpoint = endpoint

def get_conn_value(self, conn_id: str) -> str | None:
def get_conn_value(self, conn_id: str, team_name: str | None = None) -> str | None:
"""
Retrieve from Secrets Backend a string value representing the Connection object.

:param conn_id: Connection ID
:param team_name: Team name associated to the task trying to access the connection (if any)
:return: Connection Value
"""
if self.connections_prefix is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,15 @@ def build_path(path_prefix: str, secret_id: str, sep: str = "/") -> str:
"""
return f"{path_prefix}{sep}{secret_id}"

def get_conn_value(self, conn_id: str) -> str | None:
def get_conn_value(self, conn_id: str, team_name: str | None = None) -> str | None:
"""
Retrieve from Secrets Backend a string value representing the Connection object.

If the client your secrets backend uses already returns a python dict, you should override
``get_connection`` instead.

:param conn_id: connection id
:param team_name: Team name associated to the task trying to access the connection (if any)
"""
raise NotImplementedError

Expand Down Expand Up @@ -106,14 +107,15 @@ def deserialize_connection(self, conn_id: str, value: str):
return conn_class.from_uri(conn_id=conn_id, uri=value)
return conn_class(conn_id=conn_id, uri=value)

def get_connection(self, conn_id: str):
def get_connection(self, conn_id: str, team_name: str | None = None):
"""
Return connection object with a given ``conn_id``.

:param conn_id: connection id
:param team_name: Team name associated to the task trying to access the connection (if any)
:return: Connection object or None
"""
value = self.get_conn_value(conn_id=conn_id)
value = self.get_conn_value(conn_id=conn_id, team_name=team_name)
if value:
return self.deserialize_connection(conn_id=conn_id, value=value)
return None
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,21 @@ class ExecutionAPISecretsBackend(BaseSecretsBackend):
processes, not in API server/scheduler processes.
"""

def get_conn_value(self, conn_id: str) -> str | None:
def get_conn_value(self, conn_id: str, team_name: str | None = None) -> str | None:
"""
Get connection URI via SUPERVISOR_COMMS.

Not used since we override get_connection directly.
"""
raise NotImplementedError("Use get_connection instead")

def get_connection(self, conn_id: str) -> Connection | None: # type: ignore[override]
def get_connection(self, conn_id: str, team_name: str | None = None) -> Connection | None: # type: ignore[override]
"""
Return connection object by routing through SUPERVISOR_COMMS.

:param conn_id: connection id
:param team_name: Name of the team associated to the task trying to access the connection.
Unused here because the team name is inferred from the task ID provided in the execution API JWT token.
:return: Connection object or None if not found
"""
from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection
Expand Down Expand Up @@ -93,8 +95,8 @@ def get_variable(self, key: str, team_name: str | None = None) -> str | None:
Return variable value by routing through SUPERVISOR_COMMS.

:param key: Variable key
:param team_id: ID of the team associated to the task trying to access the variable.
Unused here because the team ID is inferred from the task ID provided in the execution API JWT token.
:param team_name: Name of the team associated to the task trying to access the variable.
Unused here because the team name is inferred from the task ID provided in the execution API JWT token.
:return: Variable value or None if not found
"""
from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable, VariableResult
Expand Down
Loading