Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,7 @@ def create_session(self) -> boto3.session.Session:
return self._create_session_with_assume_role(session_kwargs=self.conn.session_kwargs)

def _create_basic_session(self, session_kwargs: Dict[str, Any]) -> boto3.session.Session:
return boto3.session.Session(
aws_access_key_id=self.conn.aws_access_key_id,
aws_secret_access_key=self.conn.aws_secret_access_key,
aws_session_token=self.conn.aws_session_token,
region_name=self.region_name,
**session_kwargs,
)
return boto3.session.Session(**session_kwargs)

def _create_session_with_assume_role(self, session_kwargs: Dict[str, Any]) -> boto3.session.Session:
if self.conn.assume_role_method == 'assume_role_with_web_identity':
Expand Down Expand Up @@ -383,19 +377,20 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
def __init__(
self,
aws_conn_id: Optional[str] = default_conn_name,
verify: Union[bool, str, None] = None,
verify: Optional[Union[bool, str]] = None,
region_name: Optional[str] = None,
client_type: Optional[str] = None,
resource_type: Optional[str] = None,
config: Optional[Config] = None,
) -> None:
super().__init__()
self.aws_conn_id = aws_conn_id
self.verify = verify
self.client_type = client_type
self.resource_type = resource_type

self._region_name = region_name
self._config = config
self._verify = verify

@cached_property
def conn_config(self) -> AwsConnectionWrapper:
Expand All @@ -415,9 +410,7 @@ def conn_config(self) -> AwsConnectionWrapper:
)

return AwsConnectionWrapper(
conn=connection or Connection(conn_id=None, conn_type="aws"),
region_name=self._region_name,
botocore_config=self._config,
conn=connection, region_name=self._region_name, botocore_config=self._config, verify=self._verify
)

@property
Expand All @@ -430,6 +423,11 @@ def config(self) -> Optional[Config]:
"""Configuration for botocore client read-only property."""
return self.conn_config.botocore_config

@property
def verify(self) -> Optional[Union[bool, str]]:
"""Verify or not SSL certificates boto3 client/resource read-only property."""
return self.conn_config.verify

def get_session(self, region_name: Optional[str] = None) -> boto3.session.Session:
"""Get the underlying boto3.session.Session(region_name=region_name)."""
return SessionFactory(
Expand Down
52 changes: 35 additions & 17 deletions airflow/providers/amazon/aws/secrets/secrets_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,17 @@

import ast
import json
import re
import warnings
from typing import Any, Dict, List, Optional
from urllib.parse import unquote, urlencode

import boto3

from airflow.compat.functools import cached_property
from airflow.models.connection import Connection
from airflow.providers.amazon.aws.utils import get_airflow_version
from airflow.providers.amazon.aws.utils import get_airflow_version, trim_none_values
from airflow.secrets import BaseSecretsBackend
from airflow.utils.log.logging_mixin import LoggingMixin


def _parse_version(val):
val = re.sub(r'(\d+\.\d+\.\d+).*', lambda x: x.group(1), val)
return tuple(int(x) for x in val.split('.'))


class SecretsManagerBackend(BaseSecretsBackend, LoggingMixin):
"""
Retrieves Connection or Variables from AWS Secrets Manager
Expand All @@ -58,8 +50,17 @@ class SecretsManagerBackend(BaseSecretsBackend, LoggingMixin):
if you provide ``{"config_prefix": "airflow/config"}`` and request config
key ``sql_alchemy_conn``.

You can also pass additional keyword arguments like ``aws_secret_access_key``, ``aws_access_key_id``
or ``region_name`` to this class and they would be passed on to Boto3 client.
You can also pass additional keyword arguments listed in AWS Connection Extra config
to this class, and they would be used for establishing a connection and passed on to Boto3 client.

.. code-block:: ini

[secrets]
backend = airflow.providers.amazon.aws.secrets.secrets_manager.SecretsManagerBackend
backend_kwargs = {"connections_prefix": "airflow/connections", "region_name": "eu-west-1"}

.. seealso::
:ref:`howto/connection:aws:configuring-the-connection`

There are two ways of storing secrets in Secret Manager for using them with this operator:
storing them as a conn URI in one field, or taking advantage of native approach of Secrets Manager
Expand Down Expand Up @@ -90,7 +91,6 @@ class SecretsManagerBackend(BaseSecretsBackend, LoggingMixin):
:param config_prefix: Specifies the prefix of the secret to read to get Configurations.
If set to None (null value in the configuration), requests for configurations will not be sent to
AWS Secrets Manager. If you don't want a config_prefix, set it as an empty string
:param profile_name: The name of a profile to use. If not given, then the default profile is used.
:param sep: separator used to concatenate secret_prefix and secret_id. Default: "/"
:param full_url_mode: if True, the secrets must be stored as one conn URI in just one field per secret.
If False (set it as false in backend_kwargs), you can store the secret using different
Expand All @@ -110,7 +110,6 @@ def __init__(
connections_prefix: str = 'airflow/connections',
variables_prefix: str = 'airflow/variables',
config_prefix: str = 'airflow/config',
profile_name: Optional[str] = None,
sep: str = "/",
full_url_mode: bool = True,
are_secret_values_urlencoded: Optional[bool] = None,
Expand All @@ -130,7 +129,6 @@ def __init__(
self.config_prefix = config_prefix.rstrip(sep)
else:
self.config_prefix = config_prefix
self.profile_name = profile_name
self.sep = sep
self.full_url_mode = full_url_mode

Expand All @@ -154,14 +152,34 @@ def __init__(
)
self.are_secret_values_urlencoded = are_secret_values_urlencoded
self.extra_conn_words = extra_conn_words or {}

self.profile_name = kwargs.get("profile_name", None)
# Remove client specific arguments from kwargs
self.api_version = kwargs.pop("api_version", None)
self.use_ssl = kwargs.pop("use_ssl", None)

self.kwargs = kwargs

@cached_property
def client(self):
"""Create a Secrets Manager client"""
session = boto3.session.Session(profile_name=self.profile_name)

return session.client(service_name="secretsmanager", **self.kwargs)
from airflow.providers.amazon.aws.hooks.base_aws import SessionFactory
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper

conn_id = f"{self.__class__.__name__}__connection"
conn_config = AwsConnectionWrapper.from_connection_metadata(conn_id=conn_id, extra=self.kwargs)
client_kwargs = trim_none_values(
{
"region_name": conn_config.region_name,
"verify": conn_config.verify,
"endpoint_url": conn_config.endpoint_url,
"api_version": self.api_version,
"use_ssl": self.use_ssl,
}
)

session = SessionFactory(conn=conn_config).create_session()
return session.client(service_name="secretsmanager", **client_kwargs)

@staticmethod
def _format_uri_with_extra(secret, conn_string: str) -> str:
Expand Down
52 changes: 38 additions & 14 deletions airflow/providers/amazon/aws/secrets/systems_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,16 @@
# specific language governing permissions and limitations
# under the License.
"""Objects relating to sourcing connections from AWS SSM Parameter Store"""
import re

import warnings
from typing import Optional

import boto3

from airflow.compat.functools import cached_property
from airflow.providers.amazon.aws.utils import get_airflow_version
from airflow.providers.amazon.aws.utils import get_airflow_version, trim_none_values
from airflow.secrets import BaseSecretsBackend
from airflow.utils.log.logging_mixin import LoggingMixin


def _parse_version(val):
val = re.sub(r'(\d+\.\d+\.\d+).*', lambda x: x.group(1), val)
return tuple(int(x) for x in val.split('.'))


class SystemsManagerParameterStoreBackend(BaseSecretsBackend, LoggingMixin):
"""
Retrieves Connection or Variables from AWS SSM Parameter Store
Expand All @@ -56,15 +49,26 @@ class SystemsManagerParameterStoreBackend(BaseSecretsBackend, LoggingMixin):
If set to None (null), requests for variables will not be sent to AWS SSM Parameter Store.
:param config_prefix: Specifies the prefix of the secret to read to get Variables.
If set to None (null), requests for configurations will not be sent to AWS SSM Parameter Store.
:param profile_name: The name of a profile to use. If not given, then the default profile is used.

You can also pass additional keyword arguments listed in AWS Connection Extra config
to this class, and they would be used for establish connection and passed on to Boto3 client.

.. code-block:: ini

[secrets]
backend = airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend
backend_kwargs = {"connections_prefix": "airflow/connections", "region_name": "eu-west-1"}

.. seealso::
:ref:`howto/connection:aws:configuring-the-connection`

"""

def __init__(
self,
connections_prefix: str = '/airflow/connections',
variables_prefix: str = '/airflow/variables',
config_prefix: str = '/airflow/config',
profile_name: Optional[str] = None,
**kwargs,
):
super().__init__()
Expand All @@ -80,14 +84,34 @@ def __init__(
self.config_prefix = config_prefix.rstrip('/')
else:
self.config_prefix = config_prefix
self.profile_name = profile_name

self.profile_name = kwargs.get("profile_name", None)
# Remove client specific arguments from kwargs
self.api_version = kwargs.pop("api_version", None)
self.use_ssl = kwargs.pop("use_ssl", None)

self.kwargs = kwargs

@cached_property
def client(self):
"""Create a SSM client"""
session = boto3.Session(profile_name=self.profile_name)
return session.client("ssm", **self.kwargs)
from airflow.providers.amazon.aws.hooks.base_aws import SessionFactory
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper

conn_id = f"{self.__class__.__name__}__connection"
conn_config = AwsConnectionWrapper.from_connection_metadata(conn_id=conn_id, extra=self.kwargs)
client_kwargs = trim_none_values(
{
"region_name": conn_config.region_name,
"verify": conn_config.verify,
"endpoint_url": conn_config.endpoint_url,
"api_version": self.api_version,
"use_ssl": self.use_ssl,
}
)

session = SessionFactory(conn=conn_config).create_session()
return session.client(service_name="ssm", **client_kwargs)

def get_conn_value(self, conn_id: str) -> Optional[str]:
"""
Expand Down
Loading