Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 118 additions & 16 deletions airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,13 @@
from __future__ import annotations

import datetime
import inspect
import json
import logging
import os
import uuid
import warnings
from copy import deepcopy
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union

Expand All @@ -47,6 +51,7 @@
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.hooks.base import BaseHook
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
from airflow.providers_manager import ProvidersManager
from airflow.utils.helpers import exactly_one
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.log.secrets_masker import mask_secret
Expand Down Expand Up @@ -409,6 +414,92 @@ def __init__(
self._config = config
self._verify = verify

@classmethod
def _get_provider_version(cls) -> str:
    """
    Return the version of the provider package registered for this hook.

    The ProvidersManager is asked for the hook info matching ``cls.conn_type``;
    the package name found there is then used to look up the provider entry
    and read its version. Any failure along the way yields "Unknown".
    """
    try:
        providers_manager = ProvidersManager()
        hook_info = providers_manager.hooks[cls.conn_type]
        if hook_info:
            return providers_manager.providers[hook_info.package_name].version
        # This gets caught immediately by the handler below. Raising explicitly
        # keeps MyPy from complaining that Optional[HookInfo] may be None when
        # ``package_name`` is accessed above.
        raise ValueError(f"Hook info for {cls.conn_type} not found in the Provider Manager.")
    except Exception:
        # Under no condition should an error here ever cause an issue for the user.
        return "Unknown"

@staticmethod
def _find_class_name(target_function_name: str) -> str:
    """
    Return the name of the class whose method ``target_function_name`` is on the call stack.

    The stack is walked from the current frame outwards. The first (most recent)
    frame that is executing a function with the given name is selected, and the
    class is read from the ``self`` local of that frame.

    :param target_function_name: Name of the function to look for, e.g. "execute".
    :return: ``__name__`` of the class of the ``self`` object in the matching frame.
    :raises ValueError: If no frame on the stack is running a function with that name.
    :raises KeyError: If the matching frame has no local variable named ``self``.
    """
    # Follow the ``f_back`` links rather than calling ``inspect.stack()``: only the
    # code object's name is needed, so there is no reason to build a FrameInfo
    # record (including source context) for every frame on the stack.
    frame = inspect.currentframe()
    try:
        while frame is not None:
            if frame.f_code.co_name == target_function_name:
                # ``self`` in the frame's locals is the instance the method was called on.
                return frame.f_locals["self"].__class__.__name__
            frame = frame.f_back
    finally:
        # Release the frame explicitly so no reference cycle is left behind.
        del frame
    raise ValueError(f"No function named {target_function_name!r} found on the call stack.")

def _get_caller(self, target_function_name: str = "execute") -> str:
    """
    Walk the stack and return the name of the class which last called ``target_function_name``.

    If that class is "BaseSensorOperator", the lookup is repeated for "poke"
    instead. Any error raised during the lookup results in "Unknown".
    """
    try:
        class_name = self._find_class_name(target_function_name)
        # A BaseSensorOperator result is not specific enough; report whatever
        # last called "poke" instead.
        return self._get_caller("poke") if class_name == "BaseSensorOperator" else class_name
    except Exception:
        # Under no condition should an error here ever cause an issue for the user.
        return "Unknown"

@staticmethod
def _generate_dag_key() -> str:
    """
    Derive an opaque key from the id of the currently running DAG.

    The value of the ``AIRFLOW_CTX_DAG_ID`` environment variable is salted with the
    Object Identifier (OID) namespace and hashed into a version-5 (SHA-1 based)
    UUID, so the returned key is the hash and not the DAG id itself. If the
    variable is not set, or anything else fails, the all-zero UUID is returned.
    """
    fallback_key = "00000000-0000-0000-0000-000000000000"
    try:
        return str(uuid.uuid5(uuid.NAMESPACE_OID, os.environ["AIRFLOW_CTX_DAG_ID"]))
    except Exception:
        # Under no condition should an error here ever cause an issue for the user.
        return fallback_key

@staticmethod
def _get_airflow_version() -> str:
    """Return the version string of the running Airflow, or "Unknown" if it cannot be read."""
    try:
        # Imported inside the function on purpose: under specific configurations this
        # can be a circular import, and doing it here either avoids the problem or
        # lets the handler below catch it.
        import airflow

        return airflow.__version__
    except Exception:
        # Under no condition should an error here ever cause an issue for the user.
        return "Unknown"

def _generate_user_agent_extra_field(self, existing_user_agent_extra: str) -> str:
    """
    Build the ``user_agent_extra`` string sent with AWS requests.

    Four ``Key/Value`` tags (Airflow version, provider package version, calling
    class and hashed DAG key) are joined with single spaces, followed by any
    user agent extra value that was already configured. Surrounding whitespace
    is stripped from the result.
    """
    tags = (
        ("Airflow", self._get_airflow_version()),
        ("AmPP", self._get_provider_version()),
        ("Caller", self._get_caller()),
        ("DagRunKey", self._generate_dag_key()),
    )
    fields = [f"{key}/{value}" for key, value in tags]
    # A missing/empty existing value leaves a trailing separator, removed by strip().
    fields.append(existing_user_agent_extra or "")
    return " ".join(fields).strip()

@cached_property
def conn_config(self) -> AwsConnectionWrapper:
"""Get the Airflow Connection object and wrap it in helper (cached)."""
Expand Down Expand Up @@ -436,9 +527,9 @@ def region_name(self) -> str | None:
return self.conn_config.region_name

@property
def config(self) -> Config:
    """
    Configuration for botocore client read-only property.

    Returns the botocore config taken from the connection, or a default
    ``botocore.config.Config()`` when the connection does not provide one,
    so callers always receive a Config object and never ``None``.
    """
    return self.conn_config.botocore_config or botocore.config.Config()

@property
def verify(self) -> bool | str | None:
Expand All @@ -451,22 +542,36 @@ def get_session(self, region_name: str | None = None) -> boto3.session.Session:
conn=self.conn_config, region_name=region_name, config=self.config
).create_session()

def _get_config(self, config: Config | None = None) -> Config:
    """
    Return the botocore config to use, with the Airflow user agent tags merged in.

    No AWS Operators use the ``config`` argument to this method; it is kept for
    backward compatibility with other users who might pass it. When it is
    omitted, a deep copy of the hook's own config is used as the base.
    """
    config = deepcopy(self.config) if config is None else config

    # ignore[union-attr] is required below to appease MyPy
    # because the user_agent_extra field is generated at runtime.
    tagged_user_agent = self._generate_user_agent_extra_field(
        existing_user_agent_extra=config.user_agent_extra  # type: ignore[union-attr]
    )
    user_agent_config = Config(user_agent_extra=tagged_user_agent)
    return config.merge(user_agent_config)  # type: ignore[union-attr]

def get_client_type(
    self,
    region_name: str | None = None,
    config: Config | None = None,
) -> boto3.client:
    """
    Get the underlying boto3 client using boto3 session.

    :param region_name: Region passed on to ``get_session``.
    :param config: Optional botocore config. It is handed to ``_get_config``,
        which supplies the default when this is ``None``.
    :return: The client created by the session for ``self.client_type``.
    """
    client_type = self.client_type
    session = self.get_session(region_name=region_name)
    return session.client(
        client_type,
        endpoint_url=self.conn_config.endpoint_url,
        # Defaulting of ``config`` lives in ``_get_config`` so it is not repeated here.
        config=self._get_config(config),
        verify=self.verify,
    )

def get_resource_type(
Expand All @@ -476,15 +581,12 @@ def get_resource_type(
) -> boto3.resource:
"""Get the underlying boto3 resource using boto3 session"""
resource_type = self.resource_type

# No AWS Operators use the config argument to this method.
# Keep backward compatibility with other users who might use it
if config is None:
config = self.config

session = self.get_session(region_name=region_name)
return session.resource(
resource_type, endpoint_url=self.conn_config.endpoint_url, config=config, verify=self.verify
resource_type,
endpoint_url=self.conn_config.endpoint_url,
config=self._get_config(config),
verify=self.verify,
)

@cached_property
Expand Down
64 changes: 59 additions & 5 deletions tests/providers/amazon/aws/hooks/test_base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from base64 import b64encode
from datetime import datetime, timedelta, timezone
from unittest import mock
from uuid import UUID

import boto3
import pytest
Expand Down Expand Up @@ -301,6 +302,58 @@ def test_get_session_returns_a_boto3_session(self):

assert table.item_count == 0

@pytest.mark.parametrize(
    "client_meta",
    [
        AwsBaseHook(client_type="s3").get_client_type().meta,
        AwsBaseHook(resource_type="dynamodb").get_resource_type().meta.client.meta,
    ],
)
def test_user_agent_extra_update(self, client_meta):
    """
    Check that both a client and a resource carry the tags appended by the AwsBaseHook.

    A user_agent string is a number of space-separated key/value pairs such as:
    `BOTO3/1.25.4 AIRFLOW/2.5.0.DEV0 AMPP/6.0.0`. Only the keys added by the hook
    are looked for, and the comparison ignores case.
    """
    expected_keys = ("Airflow", "AmPP", "Caller", "DagRunKey")

    found_keys = {tag.split("/")[0].lower() for tag in client_meta.config.user_agent.split(" ")}

    missing_keys = [key for key in expected_keys if key.lower() not in found_keys]
    assert not missing_keys

@staticmethod
def fetch_tags() -> dict[str, str]:
    """
    Helper method which creates an AwsBaseHook and returns the user agent string split into a dict.

    Each whitespace-separated ``{Key}/{Value}`` tag becomes one entry. Only the
    first "/" separates key from value, so a value may itself contain "/", and a
    tag without any "/" maps to an empty string.
    """
    user_agent_string = AwsBaseHook(client_type="s3").get_client_type().meta.config.user_agent
    # split() without arguments tolerates repeated spaces; partition() always yields
    # a (key, separator, value) triple, so malformed tags cannot break dict().
    return dict(tag.partition("/")[::2] for tag in user_agent_string.split())

@pytest.mark.parametrize("found_classes", [["RandomOperator"], ["BaseSensorOperator", "TestSensor"]])
@mock.patch.object(AwsBaseHook, "_find_class_name")
def test_user_agent_caller_target_function_found(self, mock_class_name, found_classes):
    """
    The Caller tag is the last class name produced by ``_find_class_name``.

    ``_find_class_name`` is patched to return the given names one call at a time;
    it must be called once per name, and the final name must end up in the tag.
    """
    mock_class_name.side_effect = found_classes
    expected_caller = found_classes[-1]

    tags = self.fetch_tags()

    assert mock_class_name.call_count == len(found_classes)
    assert tags["Caller"] == expected_caller

def test_user_agent_caller_target_function_not_found(self):
    """Without any patching of the caller lookup, the Caller tag is expected to be "Unknown"."""
    assert self.fetch_tags()["Caller"] == "Unknown"

@pytest.mark.parametrize(
    "env_var, expected_version",
    [
        ({"AIRFLOW_CTX_DAG_ID": "banana"}, 5),
        ({}, None),
    ],
)
@mock.patch.object(AwsBaseHook, "_get_caller", return_value="Test")
def test_user_agent_dag_run_key_is_hashed_correctly(self, _, env_var, expected_version):
    """
    The DagRunKey tag parses as a UUID whose version depends on the environment.

    With ``AIRFLOW_CTX_DAG_ID`` set the key must be a version-5 UUID; with an
    empty environment the parsed UUID must report no version at all.
    """
    with mock.patch.dict(os.environ, env_var, clear=True):
        tags = self.fetch_tags()

    assert UUID(tags["DagRunKey"]).version == expected_version

@mock.patch.object(AwsBaseHook, "get_connection")
@mock_sts
def test_assume_role(self, mock_get_connection):
Expand Down Expand Up @@ -346,7 +399,7 @@ def mock_assume_role(**kwargs):
hook.get_client_type("s3")

calls_assume_role = [
mock.call.session.Session().client("sts", config=None),
mock.call.session.Session().client("sts", config=mock.ANY),
mock.call.session.Session()
.client()
.assume_role(
Expand Down Expand Up @@ -510,7 +563,7 @@ def mock_assume_role_with_saml(**kwargs):
mock_xpath.assert_called_once_with(xpath)

calls_assume_role_with_saml = [
mock.call.session.Session().client("sts", config=None),
mock.call.session.Session().client("sts", config=mock.ANY),
mock.call.session.Session()
.client()
.assume_role_with_saml(
Expand Down Expand Up @@ -735,7 +788,7 @@ def test_get_session(
mock_session_factory.assert_called_once_with(
conn=hook.conn_config,
region_name=method_region_name,
config=hook_botocore_config,
config=mock.ANY,
)
assert mock_session_factory_instance.create_session.assert_called_once
assert session == MOCK_BOTO3_SESSION
Expand Down Expand Up @@ -807,7 +860,7 @@ def __init__(self, count, quota_retry, **kwargs):

def __call__(self):
"""
Raise an Forbidden until after count threshold has been crossed.
Raise an Exception until after count threshold has been crossed.
Then return True.
"""
if self.counter < self.count:
Expand Down Expand Up @@ -839,7 +892,8 @@ def test_do_nothing_on_non_exception(self):
result = _retryable_test(lambda: 42)
assert result, 42

def test_retry_on_exception(self):
@mock.patch("time.sleep", return_value=0)
def test_retry_on_exception(self, _):
quota_retry = {
"stop_after_delay": 2,
"multiplier": 1,
Expand Down