Merged
24 changes: 21 additions & 3 deletions airflow/providers/google/cloud/sensors/pubsub.py
@@ -22,6 +22,7 @@
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Callable, Sequence

from google.cloud import pubsub_v1
from google.cloud.pubsub_v1.types import ReceivedMessage

from airflow.configuration import conf
@@ -34,6 +35,10 @@
from airflow.utils.context import Context


class PubSubMessageTransformException(AirflowException):
"""Raise when messages failed to convert pubsub received format."""


class PubSubPullSensor(BaseSensorOperator):
"""
Pulls messages from a PubSub subscription and passes them through XCom.
@@ -170,22 +175,35 @@ def execute(self, context: Context) -> None:
subscription=self.subscription,
max_messages=self.max_messages,
ack_messages=self.ack_messages,
messages_callback=self.messages_callback,
poke_interval=self.poke_interval,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
),
method_name="execute_complete",
)

def execute_complete(self, context: dict[str, Any], event: dict[str, str | list[str]]) -> str | list[str]:
"""Return immediately and relies on trigger to throw a success event. Callback for the trigger."""
def execute_complete(self, context: Context, event: dict[str, str | list[str]]) -> Any:
"""If messages_callback is provided, execute it; otherwise, return immediately with trigger event message."""
if event["status"] == "success":
self.log.info("Sensor pulls messages: %s", event["message"])
if self.messages_callback:
received_messages = self._convert_to_received_messages(event["message"])
_return_value = self.messages_callback(received_messages, context)
return _return_value

return event["message"]
self.log.info("Sensor failed: %s", event["message"])
raise AirflowException(event["message"])

def _convert_to_received_messages(self, messages: Any) -> list[ReceivedMessage]:
try:
received_messages = [pubsub_v1.types.ReceivedMessage(msg) for msg in messages]
return received_messages
except Exception as e:
raise PubSubMessageTransformException(
f"Error converting triggerer event message back to received message format: {e}"
)

def _default_message_callback(
self,
pulled_messages: list[ReceivedMessage],
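
For context, a minimal usage sketch of the sensor path this diff changes, assuming a hypothetical task id, project, subscription, and callback name; the callback signature (pulled messages plus task context) and the XCom-returned value follow the sensor's documented behavior.

from __future__ import annotations

from google.cloud.pubsub_v1.types import ReceivedMessage

from airflow.providers.google.cloud.sensors.pubsub import PubSubPullSensor
from airflow.utils.context import Context


def decode_messages(pulled_messages: list[ReceivedMessage], context: Context) -> list[str]:
    # Runs on the worker after the trigger fires; by then execute_complete has
    # rebuilt ReceivedMessage objects from the trigger event payload.
    return [msg.message.data.decode("utf-8") for msg in pulled_messages]


pull_sensor = PubSubPullSensor(
    task_id="pull_messages",            # hypothetical task id
    project_id="my-project",            # hypothetical GCP project
    subscription="my-subscription",     # hypothetical subscription
    ack_messages=True,
    deferrable=True,
    messages_callback=decode_messages,  # its return value is pushed to XCom
)

With messages_callback no longer passed to the trigger, the callback is applied on the worker in execute_complete once the deferral completes.
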
22 changes: 7 additions & 15 deletions airflow/providers/google/cloud/triggers/pubsub.py
@@ -19,16 +19,13 @@
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Sequence
from typing import Any, AsyncIterator, Sequence

from google.cloud.pubsub_v1.types import ReceivedMessage

from airflow.providers.google.cloud.hooks.pubsub import PubSubAsyncHook
from airflow.triggers.base import BaseTrigger, TriggerEvent

if TYPE_CHECKING:
from google.cloud.pubsub_v1.types import ReceivedMessage

from airflow.utils.context import Context


class PubsubPullTrigger(BaseTrigger):
"""
@@ -41,11 +38,6 @@ class PubsubPullTrigger(BaseTrigger):
:param ack_messages: If True, each message will be acknowledged
immediately rather than by any downstream tasks
:param gcp_conn_id: Reference to google cloud connection id
:param messages_callback: (Optional) Callback to process received messages.
Its return value will be saved to XCom.
If you are pulling large messages, you probably want to provide a custom callback.
If not provided, the default implementation will convert `ReceivedMessage` objects
into JSON-serializable dicts using `google.protobuf.json_format.MessageToDict` function.
:param poke_interval: polling period in seconds to check for the status
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
@@ -64,7 +56,6 @@ def __init__(
max_messages: int,
ack_messages: bool,
gcp_conn_id: str,
messages_callback: Callable[[list[ReceivedMessage], Context], Any] | None = None,
poke_interval: float = 10.0,
impersonation_chain: str | Sequence[str] | None = None,
):
@@ -73,7 +64,6 @@
self.subscription = subscription
self.max_messages = max_messages
self.ack_messages = ack_messages
self.messages_callback = messages_callback
self.poke_interval = poke_interval
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
@@ -88,7 +78,6 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"subscription": self.subscription,
"max_messages": self.max_messages,
"ack_messages": self.ack_messages,
"messages_callback": self.messages_callback,
"poke_interval": self.poke_interval,
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
@@ -106,7 +95,10 @@ async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
):
if self.ack_messages:
await self.message_acknowledgement(pulled_messages)
yield TriggerEvent({"status": "success", "message": pulled_messages})

messages_json = [ReceivedMessage.to_dict(m) for m in pulled_messages]

yield TriggerEvent({"status": "success", "message": messages_json})
return
self.log.info("Sleeping for %s seconds.", self.poke_interval)
await asyncio.sleep(self.poke_interval)
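
The trigger change above is also why messages_callback left this class: a callable cannot be round-tripped through serialize(), and the event payload itself needs to be serializable, so run() now emits plain dicts via ReceivedMessage.to_dict and the sensor rebuilds ReceivedMessage objects before invoking the callback. A standalone round-trip sketch of that conversion, with made-up field values:

from google.cloud.pubsub_v1.types import ReceivedMessage

original = ReceivedMessage(
    ack_id="1",
    message={"data": b"hello", "attributes": {"type": "example"}},
)

# What the trigger now puts into the TriggerEvent: a JSON-friendly dict
# (bytes fields such as message.data are emitted as base64 strings).
as_dict = ReceivedMessage.to_dict(original)

# What the sensor's _convert_to_received_messages does on the worker side.
restored = ReceivedMessage(as_dict)

assert restored == original
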
48 changes: 48 additions & 0 deletions tests/providers/google/cloud/sensors/test_pubsub.py
@@ -21,6 +21,7 @@
from unittest import mock

import pytest
from google.cloud import pubsub_v1
from google.cloud.pubsub_v1.types import ReceivedMessage

from airflow.exceptions import AirflowException, TaskDeferred
@@ -197,3 +198,50 @@ def test_pubsub_pull_sensor_async_execute_complete(self):
with mock.patch.object(operator.log, "info") as mock_log_info:
operator.execute_complete(context={}, event={"status": "success", "message": test_message})
mock_log_info.assert_called_with("Sensor pulls messages: %s", test_message)

@mock.patch("airflow.providers.google.cloud.sensors.pubsub.PubSubHook")
def test_pubsub_pull_sensor_async_execute_complete_use_message_callback(self, mock_hook):
test_message = [
{
"ack_id": "UAYWLF1GSFE3GQhoUQ5PXiM_NSAoRRIJB08CKF15MU0sQVhwaFENGXJ9YHxrUxsDV0ECel1RGQdoTm11H4GglfRLQ1RrWBIHB01Vel5TEwxoX11wBnm4vPO6v8vgfwk9OpX-8tltO6ywsP9GZiM9XhJLLD5-LzlFQV5AEkwkDERJUytDCypYEU4EISE-MD5FU0Q",
"message": {
"data": "aGkgZnJvbSBjbG91ZCBjb25zb2xlIQ==",
"message_id": "12165864188103151",
"publish_time": "2024-08-28T11:49:50.962Z",
"attributes": {},
"ordering_key": "",
},
"delivery_attempt": 0,
}
]

received_messages = [pubsub_v1.types.ReceivedMessage(msg) for msg in test_message]

messages_callback_return_value = "custom_message_from_callback"

def messages_callback(
pulled_messages: list[ReceivedMessage],
context: dict[str, Any],
):
assert pulled_messages == received_messages

assert isinstance(context, dict)
for key in context.keys():
assert isinstance(key, str)

return messages_callback_return_value

operator = PubSubPullSensor(
task_id="test_task",
ack_messages=True,
project_id=TEST_PROJECT,
subscription=TEST_SUBSCRIPTION,
deferrable=True,
messages_callback=messages_callback,
)
mock_hook.return_value.pull.return_value = received_messages

with mock.patch.object(operator.log, "info") as mock_log_info:
resp = operator.execute_complete(context={}, event={"status": "success", "message": test_message})
mock_log_info.assert_called_with("Sensor pulls messages: %s", test_message)
assert resp == messages_callback_return_value
55 changes: 53 additions & 2 deletions tests/providers/google/cloud/triggers/test_pubsub.py
@@ -16,9 +16,13 @@
# under the License.
from __future__ import annotations

from unittest import mock

import pytest
from google.cloud.pubsub_v1.types import ReceivedMessage

from airflow.providers.google.cloud.triggers.pubsub import PubsubPullTrigger
from airflow.triggers.base import TriggerEvent

TEST_POLL_INTERVAL = 10
TEST_GCP_CONN_ID = "google_cloud_default"
@@ -34,13 +38,25 @@ def trigger():
subscription="subscription",
max_messages=MAX_MESSAGES,
ack_messages=ACK_MESSAGES,
messages_callback=None,
poke_interval=TEST_POLL_INTERVAL,
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=None,
)


async def generate_messages(count: int) -> list[ReceivedMessage]:
return [
ReceivedMessage(
ack_id=f"{i}",
message={
"data": f"Message {i}".encode(),
"attributes": {"type": "generated message"},
},
)
for i in range(1, count + 1)
]


class TestPubsubPullTrigger:
def test_async_pubsub_pull_trigger_serialization_should_execute_successfully(self, trigger):
"""
@@ -54,8 +70,43 @@ def test_async_pubsub_pull_trigger_serialization_should_execute_successfully(self, trigger):
"subscription": "subscription",
"max_messages": MAX_MESSAGES,
"ack_messages": ACK_MESSAGES,
"messages_callback": None,
"poke_interval": TEST_POLL_INTERVAL,
"gcp_conn_id": TEST_GCP_CONN_ID,
"impersonation_chain": None,
}

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.pubsub.PubSubAsyncHook.pull")
async def test_async_pubsub_pull_trigger_return_event(self, mock_pull):
mock_pull.return_value = generate_messages(1)
trigger = PubsubPullTrigger(
project_id=PROJECT_ID,
subscription="subscription",
max_messages=MAX_MESSAGES,
ack_messages=False,
poke_interval=TEST_POLL_INTERVAL,
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=None,
)

expected_event = TriggerEvent(
{
"status": "success",
"message": [
{
"ack_id": "1",
"message": {
"data": "TWVzc2FnZSAx",
"attributes": {"type": "generated message"},
"message_id": "",
"ordering_key": "",
},
"delivery_attempt": 0,
}
],
}
)

response = await trigger.run().asend(None)

assert response == expected_event
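
For readers decoding the fixtures by hand: the data values in these tests are simply the base64 form that the dict conversion uses for bytes fields. A quick check:

import base64

assert base64.b64encode(b"Message 1").decode() == "TWVzc2FnZSAx"
assert base64.b64decode("aGkgZnJvbSBjbG91ZCBjb25zb2xlIQ==") == b"hi from cloud console!"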