Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import argparse
from datetime import datetime, timedelta, timezone

from airflow import settings
from airflow.providers.amazon.aws.hooks.eks import EksHook
from airflow.settings import configure_orm

# Presigned STS urls are valid for 15 minutes, set token expiration to 1 minute before it expires for
# some cushion
Expand Down Expand Up @@ -51,9 +53,17 @@ def get_parser():
return parser


def _ensure_orm_configured():
"""Ensure Airflow ORM is configured if engine is not set."""
if not getattr(settings, "engine", None):
configure_orm()


def main():
parser = get_parser()
args = parser.parse_args()
_ensure_orm_configured()

eks_hook = EksHook(aws_conn_id=args.aws_conn_id, region_name=args.region_name)
access_token = eks_hook.fetch_access_token_for_cluster(args.cluster_name)
access_token_expiration = get_expiration_time()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,12 @@ def test_run(self, mock_eks_hook, args, expected_aws_conn_id, expected_region_na
aws_conn_id=expected_aws_conn_id, region_name=expected_region_name
)
mock_eks_hook.return_value.fetch_access_token_for_cluster.assert_called_once_with("test-cluster")

@mock.patch("airflow.providers.amazon.aws.utils.eks_get_token.configure_orm")
@mock.patch("airflow.providers.amazon.aws.utils.eks_get_token.settings.Session", None)
@mock.patch("airflow.providers.amazon.aws.utils.eks_get_token.settings.engine", None)
def test_ensure_orm_configured_initializes_orm(self, mock_configure_orm):
import airflow.providers.amazon.aws.utils.eks_get_token as eks_get_token

eks_get_token._ensure_orm_configured()
mock_configure_orm.assert_called_once()
Loading