Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable
from uuid import uuid4
from uuid import UUID, uuid4

from azure.core.exceptions import ResourceNotFoundError
from azure.servicebus import (
Expand Down Expand Up @@ -468,7 +468,15 @@ def get_conn(self) -> ServiceBusClient:
self.log.info("Create and returns ServiceBusClient")
return client

def send_message(self, queue_name: str, messages: str | list[str], batch_message_flag: bool = False):
def send_message(
self,
queue_name: str,
messages: str | list[str],
batch_message_flag: bool = False,
message_id: str | None = None,
reply_to: str | None = None,
message_headers: dict[str | bytes, int | float | bytes | bool | str | UUID] | None = None,
):
"""
Use ServiceBusClient Send to send message(s) to a Service Bus Queue.

Expand All @@ -478,38 +486,49 @@ def send_message(self, queue_name: str, messages: str | list[str], batch_message
:param messages: Message which needs to be sent to the queue. It can be string or list of string.
:param batch_message_flag: bool flag, can be set to True if message needs to be
sent as batch message.
:param message_id: Message ID to set on message being sent to the queue. Please note, message_id may only be
set when a single message is sent.
:param reply_to: Reply to which needs to be sent to the queue.
:param message_headers: Headers to add to the message's application_properties field for Azure Service Bus.
"""
if queue_name is None:
raise TypeError("Queue name cannot be None.")
if not messages:
raise ValueError("Messages list cannot be empty.")
if message_id and isinstance(messages, list) and len(messages) != 1:
raise TypeError("Message ID can only be set if a single message is sent.")
with (
self.get_conn() as service_bus_client,
service_bus_client.get_queue_sender(queue_name=queue_name) as sender,
sender,
):
if isinstance(messages, str):
if not batch_message_flag:
msg = ServiceBusMessage(messages)
sender.send_messages(msg)
else:
self.send_batch_message(sender, [messages])
message_creator = lambda msg_body: ServiceBusMessage(
msg_body, message_id=message_id, reply_to=reply_to, application_properties=message_headers
)
message_list = [messages] if isinstance(messages, str) else messages
if not batch_message_flag:
self.send_list_messages(sender, message_list, message_creator)
else:
if not batch_message_flag:
self.send_list_messages(sender, messages)
else:
self.send_batch_message(sender, messages)
self.send_batch_message(sender, message_list, message_creator)

@staticmethod
def send_list_messages(sender: ServiceBusSender, messages: list[str]):
list_messages = [ServiceBusMessage(message) for message in messages]
def send_list_messages(
sender: ServiceBusSender,
messages: list[str],
message_creator: Callable[[str], ServiceBusMessage],
):
list_messages = [message_creator(body) for body in messages]
sender.send_messages(list_messages) # type: ignore[arg-type]

@staticmethod
def send_batch_message(sender: ServiceBusSender, messages: list[str]):
def send_batch_message(
sender: ServiceBusSender,
messages: list[str],
message_creator: Callable[[str], ServiceBusMessage],
):
batch_message = sender.create_message_batch()
for message in messages:
batch_message.add_message(ServiceBusMessage(message))
batch_message.add_message(message_creator(message))
sender.send_messages(batch_message)

def receive_message(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Callable
from uuid import UUID

from airflow.models import BaseOperator
from airflow.providers.microsoft.azure.hooks.asb import AdminClientHook, MessageHook
Expand Down Expand Up @@ -100,6 +101,11 @@ class AzureServiceBusSendMessageOperator(BaseOperator):
as batch message it can be set to True.
:param azure_service_bus_conn_id: Reference to the
:ref: `Azure Service Bus connection<howto/connection:azure_service_bus>`.
:param message_id: Message ID to set on message being sent to the queue. Please note, message_id may only be
set when a single message is sent.
:param reply_to: Name of queue or topic the receiver should reply to. Determination of if the reply will be sent to
a queue or a topic should be made out-of-band.
:param message_headers: Headers to add to the message's application_properties field for Azure Service Bus.
"""

template_fields: Sequence[str] = ("queue_name",)
Expand All @@ -112,21 +118,29 @@ def __init__(
message: str | list[str],
batch: bool = False,
azure_service_bus_conn_id: str = "azure_service_bus_default",
message_id: str | None = None,
reply_to: str | None = None,
message_headers: dict[str | bytes, int | float | bytes | bool | str | UUID] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.queue_name = queue_name
self.batch = batch
self.message = message
self.azure_service_bus_conn_id = azure_service_bus_conn_id
self.message_id = message_id
self.reply_to = reply_to
self.message_headers = message_headers

def execute(self, context: Context) -> None:
"""Send Message to the specific queue in Service Bus namespace."""
# Create the hook
hook = MessageHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id)

# send message
hook.send_message(self.queue_name, self.message, self.batch)
hook.send_message(
self.queue_name, self.message, self.batch, self.message_id, self.reply_to, self.message_headers
)


class AzureServiceBusReceiveMessageOperator(BaseOperator):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,43 @@ def test_send_message(
]
mock_sb_client.assert_has_calls(expected_calls, any_order=False)

@mock.patch(f"{MODULE}.MessageHook.get_conn", autospec=True)
@mock.patch("azure.servicebus.ServiceBusSender", autospec=True)
def test_send_message_with_id_reply_to_and_headers(self, mock_q_sender, mock_sb_client):
"""
Test `send_message` hook function with batch flag and message passed as mocked params,
which can be string or list of string, mock the azure service bus `send_messages` function
"""
sent_messages = []

def mock_send_messages(messages):
nonlocal sent_messages
sent_messages.extend(messages)

mock_sb_client.return_value.__enter__.return_value.get_queue_sender.return_value.__enter__.return_value = mock_q_sender
mock_q_sender.send_messages.side_effect = mock_send_messages

MSG_ID = "test_msg_id"
REPLY_TO = "test_reply_to"
HEADERS = {"test-key": "test-value"}
hook = MessageHook(azure_service_bus_conn_id="azure_service_bus_default")
hook.send_message(
queue_name="test_queue",
messages=MESSAGE,
batch_message_flag=False,
message_id=MSG_ID,
reply_to=REPLY_TO,
message_headers=HEADERS,
)

mock_q_sender.send_messages.assert_called_once()

assert len(sent_messages) == 1
assert str(sent_messages[0]) == MESSAGE
assert sent_messages[0].message_id == MSG_ID
assert sent_messages[0].reply_to == REPLY_TO
assert sent_messages[0].application_properties == HEADERS

@mock.patch(f"{MODULE}.MessageHook.get_conn")
def test_send_message_exception(self, mock_sb_client):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,52 +131,80 @@ def test_delete_queue(self, mock_get_conn):

class TestAzureServiceBusSendMessageOperator:
@pytest.mark.parametrize(
"mock_message, mock_batch_flag",
"mock_message, mock_batch_flag, mock_message_id, mock_reply_to, mock_headers",
[
(MESSAGE, True),
(MESSAGE, False),
(MESSAGE_LIST, True),
(MESSAGE_LIST, False),
(MESSAGE, True, None, None, None),
(MESSAGE, False, "test_message_id", "test_reply_to", {"test_header": "test_value"}),
(MESSAGE_LIST, True, None, None, None),
(MESSAGE_LIST, False, None, None, None),
],
)
def test_init(self, mock_message, mock_batch_flag):
def test_init(self, mock_message, mock_batch_flag, mock_message_id, mock_reply_to, mock_headers):
"""
Test init by creating AzureServiceBusSendMessageOperator with task id, queue_name, message,
batch and asserting with values
batch, message_id, reply_to, and message headers and asserting with values
"""
asb_send_message_queue_operator = AzureServiceBusSendMessageOperator(
task_id="asb_send_message_queue_without_batch",
queue_name=QUEUE_NAME,
message=mock_message,
batch=mock_batch_flag,
message_id=mock_message_id,
reply_to=mock_reply_to,
message_headers=mock_headers,
)
assert asb_send_message_queue_operator.task_id == "asb_send_message_queue_without_batch"
assert asb_send_message_queue_operator.queue_name == QUEUE_NAME
assert asb_send_message_queue_operator.message == mock_message
assert asb_send_message_queue_operator.batch is mock_batch_flag
assert asb_send_message_queue_operator.message_id == mock_message_id
assert asb_send_message_queue_operator.reply_to == mock_reply_to
assert asb_send_message_queue_operator.message_headers == mock_headers

@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.get_conn")
def test_send_message_queue(self, mock_get_conn):
@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.send_message")
def test_send_message_queue(self, mock_send_message):
"""
Test AzureServiceBusSendMessageOperator with queue name, batch boolean flag, mock
the send_messages of azure service bus function
"""
TASK_ID = "task-id"
MSG_BODY = "test message body"
MSG_ID = None
REPLY_TO = None
HDRS = None
asb_send_message_queue_operator = AzureServiceBusSendMessageOperator(
task_id="asb_send_message_queue",
task_id=TASK_ID,
queue_name=QUEUE_NAME,
message="Test message",
message=MSG_BODY,
batch=False,
)
asb_send_message_queue_operator.execute(None)
expected_calls = [
mock.call()
.__enter__()
.get_queue_sender(QUEUE_NAME)
.__enter__()
.send_messages(ServiceBusMessage("Test message"))
.__exit__()
]
mock_get_conn.assert_has_calls(expected_calls, any_order=False)
expected_calls = [mock.call(QUEUE_NAME, MSG_BODY, False, MSG_ID, REPLY_TO, HDRS)]
mock_send_message.assert_has_calls(expected_calls, any_order=False)

@mock.patch("airflow.providers.microsoft.azure.hooks.asb.MessageHook.send_message")
def test_send_message_queue_with_id_hdrs_and_reply_to(self, mock_send_message):
"""
Test AzureServiceBusSendMessageOperator with queue name, batch boolean flag, mock
the send_messages of azure service bus function
"""
TASK_ID = "task-id"
MSG_ID = "test_message_id"
MSG_BODY = "test message body"
REPLY_TO = "test_reply_to"
HDRS = {"test_header": "test_value"}
asb_send_message_queue_operator = AzureServiceBusSendMessageOperator(
task_id=TASK_ID,
queue_name=QUEUE_NAME,
message=MSG_BODY,
batch=False,
message_id=MSG_ID,
reply_to=REPLY_TO,
message_headers=HDRS,
)
asb_send_message_queue_operator.execute(None)
expected_calls = [mock.call(QUEUE_NAME, MSG_BODY, False, MSG_ID, REPLY_TO, HDRS)]
mock_send_message.assert_has_calls(expected_calls, any_order=False)


class TestAzureServiceBusReceiveMessageOperator:
Expand Down