Skip to content
Merged
28 changes: 18 additions & 10 deletions airflow/providers/amazon/aws/hooks/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.exceptions import S3HookUriParseFailure
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.utils.tags import format_tags
from airflow.utils.helpers import chunks

T = TypeVar("T", bound=Callable)
Expand Down Expand Up @@ -1052,36 +1053,43 @@ def get_bucket_tagging(self, bucket_name: str | None = None) -> list[dict[str, s
@provide_bucket_name
def put_bucket_tagging(
    self,
    tag_set: dict[str, str] | list[dict[str, str]] | None = None,
    key: str | None = None,
    value: str | None = None,
    bucket_name: str | None = None,
) -> None:
    """
    Overwrites the existing TagSet with provided tags.
    Must provide a TagSet, a key/value pair, or both.

    .. seealso::
        - :external+boto3:py:meth:`S3.Client.put_bucket_tagging`

    :param tag_set: A dictionary containing the key/value pairs for the tags,
        or a list already formatted for the API
    :param key: The Key for the new TagSet entry.
    :param value: The Value for the new TagSet entry.
    :param bucket_name: The name of the bucket.

    :return: None
    """
    # Normalize a dict into the boto3 [{"Key": ..., "Value": ...}] shape;
    # a list (or None) is handled by format_tags and needs no conversion.
    formatted_tags = format_tags(tag_set)

    if key and value:
        # Build a new list instead of appending in place: format_tags returns
        # a caller-supplied list unchanged, and mutating it here would leak the
        # extra entry back into the caller's argument.
        formatted_tags = [*formatted_tags, {"Key": key, "Value": value}]
    elif key or value:
        message = (
            "Key and Value must be specified as a pair. "
            f"Only one of the two had a value (key: '{key}', value: '{value}')"
        )
        self.log.error(message)
        raise ValueError(message)

    self.log.info("Tagging S3 Bucket %s with %s", bucket_name, formatted_tags)

    try:
        s3_client = self.get_conn()
        s3_client.put_bucket_tagging(Bucket=bucket_name, Tagging={"TagSet": formatted_tags})
    except ClientError as e:
        self.log.error(e)
        raise e
Expand Down
5 changes: 2 additions & 3 deletions airflow/providers/amazon/aws/hooks/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.utils.tags import format_tags
from airflow.utils import timezone


Expand Down Expand Up @@ -1100,9 +1101,7 @@ def start_pipeline(

:return: the ARN of the pipeline execution launched.
"""
if pipeline_params is None:
pipeline_params = {}
formatted_params = [{"Name": kvp[0], "Value": kvp[1]} for kvp in pipeline_params.items()]
formatted_params = format_tags(pipeline_params, key_label="Name")

try:
res = self.conn.start_pipeline_execution(
Expand Down
32 changes: 18 additions & 14 deletions airflow/providers/amazon/aws/operators/rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.rds import RdsHook
from airflow.providers.amazon.aws.utils.rds import RdsDbType
from airflow.providers.amazon.aws.utils.tags import format_tags

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -64,7 +65,7 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
:param db_type: Type of the DB - either "instance" or "cluster"
:param db_identifier: The identifier of the instance or cluster that you want to create the snapshot of
:param db_snapshot_identifier: The identifier for the DB snapshot
:param tags: A list of tags in format `[{"Key": "something", "Value": "something"},]
:param tags: A dictionary of tags or a list of tags in format `[{"Key": "...", "Value": "..."},]`
`USER Tagging <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
:param wait_for_completion: If True, waits for creation of the DB snapshot to complete. (default: True)
"""
Expand All @@ -77,7 +78,7 @@ def __init__(
db_type: str,
db_identifier: str,
db_snapshot_identifier: str,
tags: Sequence[TagTypeDef] | None = None,
tags: Sequence[TagTypeDef] | dict | None = None,
wait_for_completion: bool = True,
aws_conn_id: str = "aws_conn_id",
**kwargs,
Expand All @@ -86,7 +87,7 @@ def __init__(
self.db_type = RdsDbType(db_type)
self.db_identifier = db_identifier
self.db_snapshot_identifier = db_snapshot_identifier
self.tags = tags or []
self.tags = tags
self.wait_for_completion = wait_for_completion

def execute(self, context: Context) -> str:
Expand All @@ -97,11 +98,12 @@ def execute(self, context: Context) -> str:
self.db_snapshot_identifier,
)

formatted_tags = format_tags(self.tags)
if self.db_type.value == "instance":
create_instance_snap = self.hook.conn.create_db_snapshot(
DBInstanceIdentifier=self.db_identifier,
DBSnapshotIdentifier=self.db_snapshot_identifier,
Tags=self.tags,
Tags=formatted_tags,
)
create_response = json.dumps(create_instance_snap, default=str)
if self.wait_for_completion:
Expand All @@ -110,7 +112,7 @@ def execute(self, context: Context) -> str:
create_cluster_snap = self.hook.conn.create_db_cluster_snapshot(
DBClusterIdentifier=self.db_identifier,
DBClusterSnapshotIdentifier=self.db_snapshot_identifier,
Tags=self.tags,
Tags=formatted_tags,
)
create_response = json.dumps(create_cluster_snap, default=str)
if self.wait_for_completion:
Expand All @@ -132,7 +134,7 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
:param source_db_snapshot_identifier: The identifier of the source snapshot
:param target_db_snapshot_identifier: The identifier of the target snapshot
:param kms_key_id: The AWS KMS key identifier for an encrypted DB snapshot
:param tags: A list of tags in format `[{"Key": "something", "Value": "something"},]
:param tags: A dictionary of tags or a list of tags in format `[{"Key": "...", "Value": "..."},]`
`USER Tagging <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
:param copy_tags: Whether to copy all tags from the source snapshot to the target snapshot (default False)
:param pre_signed_url: The URL that contains a Signature Version 4 signed request
Expand All @@ -159,7 +161,7 @@ def __init__(
source_db_snapshot_identifier: str,
target_db_snapshot_identifier: str,
kms_key_id: str = "",
tags: Sequence[TagTypeDef] | None = None,
tags: Sequence[TagTypeDef] | dict | None = None,
copy_tags: bool = False,
pre_signed_url: str = "",
option_group_name: str = "",
Expand All @@ -175,7 +177,7 @@ def __init__(
self.source_db_snapshot_identifier = source_db_snapshot_identifier
self.target_db_snapshot_identifier = target_db_snapshot_identifier
self.kms_key_id = kms_key_id
self.tags = tags or []
self.tags = tags
self.copy_tags = copy_tags
self.pre_signed_url = pre_signed_url
self.option_group_name = option_group_name
Expand All @@ -190,12 +192,13 @@ def execute(self, context: Context) -> str:
self.target_db_snapshot_identifier,
)

formatted_tags = format_tags(self.tags)
if self.db_type.value == "instance":
copy_instance_snap = self.hook.conn.copy_db_snapshot(
SourceDBSnapshotIdentifier=self.source_db_snapshot_identifier,
TargetDBSnapshotIdentifier=self.target_db_snapshot_identifier,
KmsKeyId=self.kms_key_id,
Tags=self.tags,
Tags=formatted_tags,
CopyTags=self.copy_tags,
PreSignedUrl=self.pre_signed_url,
OptionGroupName=self.option_group_name,
Expand All @@ -212,7 +215,7 @@ def execute(self, context: Context) -> str:
SourceDBClusterSnapshotIdentifier=self.source_db_snapshot_identifier,
TargetDBClusterSnapshotIdentifier=self.target_db_snapshot_identifier,
KmsKeyId=self.kms_key_id,
Tags=self.tags,
Tags=formatted_tags,
CopyTags=self.copy_tags,
PreSignedUrl=self.pre_signed_url,
SourceRegion=self.source_region,
Expand Down Expand Up @@ -403,7 +406,7 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
`USER Events <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Events.Messages.html>`__
:param source_ids: The list of identifiers of the event sources for which events are returned
:param enabled: A value that indicates whether to activate the subscription (default True)
:param tags: A list of tags in format `[{"Key": "something", "Value": "something"},]
:param tags: A dictionary of tags or a list of tags in format `[{"Key": "...", "Value": "..."},]`
`USER Tagging <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
:param wait_for_completion: If True, waits for creation of the subscription to complete. (default: True)
"""
Expand All @@ -426,7 +429,7 @@ def __init__(
event_categories: Sequence[str] | None = None,
source_ids: Sequence[str] | None = None,
enabled: bool = True,
tags: Sequence[TagTypeDef] | None = None,
tags: Sequence[TagTypeDef] | dict | None = None,
wait_for_completion: bool = True,
aws_conn_id: str = "aws_default",
**kwargs,
Expand All @@ -439,20 +442,21 @@ def __init__(
self.event_categories = event_categories or []
self.source_ids = source_ids or []
self.enabled = enabled
self.tags = tags or []
self.tags = tags
self.wait_for_completion = wait_for_completion

def execute(self, context: Context) -> str:
self.log.info("Creating event subscription '%s' to '%s'", self.subscription_name, self.sns_topic_arn)

formatted_tags = format_tags(self.tags)
create_subscription = self.hook.conn.create_event_subscription(
SubscriptionName=self.subscription_name,
SnsTopicArn=self.sns_topic_arn,
SourceType=self.source_type,
EventCategories=self.event_categories,
SourceIds=self.source_ids,
Enabled=self.enabled,
Tags=self.tags,
Tags=formatted_tags,
)

if self.wait_for_completion:
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/operators/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class S3PutBucketTaggingOperator(BaseOperator):
If a key is provided, a value must be provided as well.
:param value: The value portion of the key/value pair for a tag to be added.
If a value is provided, a key must be provided as well.
:param tag_set: A List of key/value pairs.
:param tag_set: A dictionary containing the tags, or a List of key/value pairs.
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
Expand All @@ -179,7 +179,7 @@ def __init__(
bucket_name: str,
key: str | None = None,
value: str | None = None,
tag_set: list[dict[str, str]] | None = None,
tag_set: dict | list[dict[str, str]] | None = None,
aws_conn_id: str | None = "aws_default",
**kwargs,
) -> None:
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/operators/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.utils import trim_none_values
from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus
from airflow.providers.amazon.aws.utils.tags import format_tags
from airflow.utils.json import AirflowJsonEncoder

if TYPE_CHECKING:
Expand Down Expand Up @@ -1090,11 +1091,10 @@ def __init__(

def execute(self, context: Context) -> str:
sagemaker_hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
tags_set = [{"Key": kvp[0], "Value": kvp[1]} for kvp in self.tags.items()]
params = {
"ExperimentName": self.name,
"Description": self.description,
"Tags": tags_set,
"Tags": format_tags(self.tags),
}
ans = sagemaker_hook.conn.create_experiment(**trim_none_values(params))
arn = ans["ExperimentArn"]
Expand Down
38 changes: 38 additions & 0 deletions airflow/providers/amazon/aws/utils/tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import Any


def format_tags(source: Any, *, key_label: str = "Key", value_label: str = "Value"):
    """
    If given a dictionary, formats it as an array of objects with a key and a value field to be passed to boto
    calls that expect this format.
    Else, assumes that it's already in the right format and returns it as is. We do not validate
    the format here since it's done by boto anyway, and the error wouldn't be clearer if thrown from here.

    :param source: a dict from which keys and values are read
    :param key_label: optional, the label to use for keys if not "Key"
    :param value_label: optional, the label to use for values if not "Value"
    """
    if source is None:
        # No tags provided: boto expects an (empty) list, never None.
        return []
    if isinstance(source, dict):
        # Tuple unpacking instead of indexing each items() pair.
        return [{key_label: key, value_label: value} for key, value in source.items()]
    # NOTE(review): non-dict input is returned as the same object (no copy), so
    # callers that mutate the result also mutate their argument — confirm callers
    # treat the return value as read-only or copy before extending.
    return source
9 changes: 9 additions & 0 deletions tests/providers/amazon/aws/hooks/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,15 @@ def test_put_bucket_tagging_with_valid_set(self):

assert hook.get_bucket_tagging(bucket_name="new_bucket") == tag_set

@mock_s3
def test_put_bucket_tagging_with_dict(self):
    """A plain dict tag_set must be stored as a boto-style Key/Value TagSet."""
    hook = S3Hook()
    bucket = "new_bucket"
    hook.create_bucket(bucket_name=bucket)

    hook.put_bucket_tagging(bucket_name=bucket, tag_set={"Color": "Green"})

    expected = [{"Key": "Color", "Value": "Green"}]
    assert hook.get_bucket_tagging(bucket_name=bucket) == expected

@mock_s3
def test_put_bucket_tagging_with_pair(self):
hook = S3Hook()
Expand Down