Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 15 additions & 20 deletions airflow/secrets/metastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,48 +23,43 @@

from sqlalchemy import select

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.secrets import BaseSecretsBackend
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
from sqlalchemy.orm import Session

from airflow.models.connection import Connection
from airflow.models import Connection


class MetastoreBackend(BaseSecretsBackend):
"""Retrieves Connection object and Variable from airflow metastore database."""

@provide_session
def get_connection(self, conn_id: str, session: Session = NEW_SESSION) -> Connection | None:
return MetastoreBackend._fetch_connection(conn_id, session=session)

@provide_session
def get_variable(self, key: str, session: Session = NEW_SESSION) -> str | None:
"""
Get Airflow Variable from Metadata DB.
Get Airflow Connection from Metadata DB.

:param key: Variable Key
:return: Variable Value
:param conn_id: Connection ID
:param session: SQLAlchemy Session
:return: Connection Object
"""
return MetastoreBackend._fetch_variable(key=key, session=session)

@staticmethod
@internal_api_call
@provide_session
def _fetch_connection(conn_id: str, session: Session = NEW_SESSION) -> Connection | None:
from airflow.models.connection import Connection
from airflow.models import Connection

conn = session.scalar(select(Connection).where(Connection.conn_id == conn_id).limit(1))
session.expunge_all()
return conn

@staticmethod
@internal_api_call
@provide_session
def _fetch_variable(key: str, session: Session = NEW_SESSION) -> str | None:
from airflow.models.variable import Variable
def get_variable(self, key: str, session: Session = NEW_SESSION) -> str | None:
"""
Get Airflow Variable from Metadata DB.

:param key: Variable Key
:param session: SQLAlchemy Session
:return: Variable Value
"""
from airflow.models import Variable

var_value = session.scalar(select(Variable).where(Variable.key == key).limit(1))
session.expunge_all()
Expand Down