Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion providers/amazon/src/airflow/providers/amazon/aws/hooks/sns.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def publish_to_target(
message: str,
subject: str | None = None,
message_attributes: dict | None = None,
message_deduplication_id: str | None = None,
message_group_id: str | None = None,
):
"""
Publish a message to a SNS topic or an endpoint.
Expand All @@ -77,7 +79,10 @@ def publish_to_target(
- str = String
- int, float = Number
- iterable = String.Array

:param message_deduplication_id: Every message must have a unique message_deduplication_id.
This parameter applies only to FIFO (first-in-first-out) topics.
:param message_group_id: Tag that specifies that a message belongs to a specific message group.
This parameter applies only to FIFO (first-in-first-out) topics.
"""
publish_kwargs: dict[str, str | dict] = {
"TargetArn": target_arn,
Expand All @@ -88,6 +93,10 @@ def publish_to_target(
# Construct args this way because boto3 distinguishes from missing args and those set to None
if subject:
publish_kwargs["Subject"] = subject
if message_deduplication_id:
publish_kwargs["MessageDeduplicationId"] = message_deduplication_id
if message_group_id:
publish_kwargs["MessageGroupId"] = message_group_id
if message_attributes:
publish_kwargs["MessageAttributes"] = {
key: _get_message_attribute(val) for key, val in message_attributes.items()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ class SnsPublishOperator(AwsBaseOperator[SnsHook]):
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
:param message_deduplication_id: Every message must have a unique message_deduplication_id.
This parameter applies only to FIFO (first-in-first-out) topics.
:param message_group_id: Tag that specifies that a message belongs to a specific message group.
This parameter applies only to FIFO (first-in-first-out) topics.
"""

aws_hook_class = SnsHook
Expand All @@ -61,6 +65,8 @@ class SnsPublishOperator(AwsBaseOperator[SnsHook]):
"message",
"subject",
"message_attributes",
"message_deduplication_id",
"message_group_id",
)
template_fields_renderers = {"message_attributes": "json"}

Expand All @@ -71,27 +77,35 @@ def __init__(
message: str,
subject: str | None = None,
message_attributes: dict | None = None,
message_deduplication_id: str | None = None,
message_group_id: str | None = None,
**kwargs,
):
super().__init__(**kwargs)
self.target_arn = target_arn
self.message = message
self.subject = subject
self.message_attributes = message_attributes
self.message_deduplication_id = message_deduplication_id
self.message_group_id = message_group_id

def execute(self, context: Context):
self.log.info(
"Sending SNS notification to %s using %s:\nsubject=%s\nattributes=%s\nmessage=%s",
"Sending SNS notification to %s using %s:\nsubject=%s\nattributes=%s\nmessage=%s\nmessage_deduplication_id=%s\nmessage_group_id=%s",
self.target_arn,
self.aws_conn_id,
self.subject,
self.message_attributes,
self.message,
self.message_deduplication_id,
self.message_group_id,
)

return self.hook.publish_to_target(
target_arn=self.target_arn,
message=self.message,
subject=self.subject,
message_attributes=self.message_attributes,
message_deduplication_id=self.message_deduplication_id,
message_group_id=self.message_group_id,
)
41 changes: 35 additions & 6 deletions providers/amazon/tests/unit/amazon/aws/hooks/test_sns.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@

from airflow.providers.amazon.aws.hooks.sns import SnsHook

MESSAGE = "Hello world"
TOPIC_NAME = "test-topic"
SUBJECT = "test-subject"


@mock_aws
class TestSnsHook:
Expand All @@ -32,9 +36,9 @@ def test_get_conn_returns_a_boto3_connection(self):
def test_publish_to_target_with_subject(self):
hook = SnsHook(aws_conn_id="aws_default")

message = "Hello world"
topic_name = "test-topic"
subject = "test-subject"
message = MESSAGE
topic_name = TOPIC_NAME
subject = SUBJECT
target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn")

response = hook.publish_to_target(target, message, subject)
Expand All @@ -44,8 +48,8 @@ def test_publish_to_target_with_subject(self):
def test_publish_to_target_with_attributes(self):
hook = SnsHook(aws_conn_id="aws_default")

message = "Hello world"
topic_name = "test-topic"
message = MESSAGE
topic_name = TOPIC_NAME
target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn")

response = hook.publish_to_target(
Expand All @@ -64,7 +68,7 @@ def test_publish_to_target_with_attributes(self):
def test_publish_to_target_plain(self):
hook = SnsHook(aws_conn_id="aws_default")

message = "Hello world"
message = MESSAGE
topic_name = "test-topic"
target = hook.get_conn().create_topic(Name=topic_name).get("TopicArn")

Expand All @@ -90,3 +94,28 @@ def test_publish_to_target_error(self):
"test-non-iterable": object(),
},
)

def test_publish_to_target_with_deduplication(self):
hook = SnsHook(aws_conn_id="aws_default")

message = MESSAGE
topic_name = TOPIC_NAME + ".fifo"
deduplication_id = "abc"
group_id = "a"
target = (
hook.get_conn()
.create_topic(
Name=topic_name,
Attributes={
"FifoTopic": "true",
"ContentBasedDeduplication": "false",
},
)
.get("TopicArn")
)

response = hook.publish_to_target(
target, message, message_deduplication_id=deduplication_id, message_group_id=group_id
)

assert "MessageId" in response
27 changes: 23 additions & 4 deletions providers/amazon/tests/unit/amazon/aws/operators/test_sns.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

TASK_ID = "sns_publish_job"
AWS_CONN_ID = "custom_aws_conn"
TARGET_ARN = "arn:aws:sns:eu-central-1:1234567890:test-topic"
TARGET_ARN = "test-topic.fifo"
MESSAGE = "Message to send"
SUBJECT = "Subject to send"
MESSAGE_ATTRIBUTES = {"test-attribute": "Attribute to send"}
Expand Down Expand Up @@ -57,6 +57,8 @@ def test_init(self):
region_name="us-west-1",
verify="/spam/egg.pem",
botocore_config={"read_timeout": 42},
message_deduplication_id="abc",
message_group_id="a",
)
assert op.hook.aws_conn_id == AWS_CONN_ID
assert op.hook._region_name == "us-west-1"
Expand All @@ -65,20 +67,37 @@ def test_init(self):
assert op.hook._config.read_timeout == 42

@mock.patch.object(SnsPublishOperator, "hook")
def test_execute(self, mocked_hook):
@pytest.mark.parametrize(
"message_deduplication_id_,message_group_id_",
[
("abc", "a"),
(None, None),
("abc", None),
(None, "a"),
],
)
def test_execute(self, mocked_hook, message_deduplication_id_, message_group_id_):
hook_response = {"MessageId": "foobar"}
mocked_hook.publish_to_target.return_value = hook_response

op = SnsPublishOperator(**self.default_op_kwargs)
op = SnsPublishOperator(
**self.default_op_kwargs,
message_deduplication_id=message_deduplication_id_,
message_group_id=message_group_id_,
)
assert op.execute({}) == hook_response

mocked_hook.publish_to_target.assert_called_once_with(
message=MESSAGE,
message_attributes=MESSAGE_ATTRIBUTES,
subject=SUBJECT,
target_arn=TARGET_ARN,
message_deduplication_id=message_deduplication_id_,
message_group_id=message_group_id_,
)

def test_template_fields(self):
operator = SnsPublishOperator(**self.default_op_kwargs)
operator = SnsPublishOperator(
**self.default_op_kwargs, message_deduplication_id="abc", message_group_id="a"
)
validate_template_fields(operator)