Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented MSGraphSensor as a deferrable sensor #39304

Merged
merged 30 commits into from
May 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
e32422e
refactor: Implement default response handler method and added test wh…
davidblain-infrabel Apr 29, 2024
c3c021f
refactor: Reformatted some code to comply to static checks
davidblain-infrabel Apr 29, 2024
445b368
refactor: Changed debugging level to debug for printing response in o…
davidblain-infrabel Apr 29, 2024
e0d6def
docs: Added example on how to refresh a PowerBI dataset using the MSG…
davidblain-infrabel Apr 29, 2024
3941293
refactor: Changed some info logging statements to debug
davidblain-infrabel Apr 29, 2024
863d801
refactor: Changed some info logging statements to debug
davidblain-infrabel Apr 29, 2024
cf18bbe
fix: Fixed mock_json_response
davidblain-infrabel Apr 29, 2024
769cf65
Merge branch 'main' into feature/default_response_handler
dabla Apr 29, 2024
4f6e08f
refactor: Return content if response is not a JSON
davidblain-infrabel Apr 29, 2024
4149b8f
refactor: Make sure the operator passes the response_handler to the t…
davidblain-infrabel Apr 29, 2024
86aec51
refactor: Should use get instead of directly _getitem_ brackets as pa…
davidblain-infrabel Apr 29, 2024
bf65682
refactor: If event has status failure then the sensor should stop the…
davidblain-infrabel Apr 29, 2024
e8b5af2
refactor: Changed default_event_processor as not all responses have t…
davidblain-infrabel Apr 29, 2024
a9778ec
refactor: Changed default_event_processor as not all responses have t…
davidblain-infrabel Apr 29, 2024
3c4db25
Merge branch 'main' into feature/default_response_handler
dabla Apr 30, 2024
696b2c9
refactor: Removed response_handler parameter as lambda cannot be seri…
davidblain-infrabel Apr 30, 2024
7744702
refactor: Changed some logging statements
davidblain-infrabel Apr 30, 2024
6729187
refactor: Updated PowerBI dataset refresh example
davidblain-infrabel Apr 30, 2024
64ad330
refactor: Fixed 2 static check errors
davidblain-infrabel Apr 30, 2024
8008e8f
refactor: Refactored MSGraphSensor as a real async sensor
davidblain-infrabel Apr 30, 2024
a74d810
refactor: Changed logging level of sensor statements back to debug
davidblain-infrabel Apr 30, 2024
6f4365b
refactor: Fixed 2 static checks
davidblain-infrabel Apr 30, 2024
d41eb3a
refactor: Changed docstring hook
davidblain-infrabel May 1, 2024
76add21
refactor: Put docstring on one line
davidblain-infrabel May 1, 2024
d1fbc0f
Merge branch 'main' into feature/default_response_handler
dabla May 1, 2024
9efd969
Merge branch 'main' into feature/default_response_handler
dabla May 2, 2024
66a3d90
Merge branch 'main' into feature/default_response_handler
dabla May 2, 2024
68ccee9
Merge branch 'main' into feature/default_response_handler
dabla May 2, 2024
ceb102c
Merge branch 'main' into feature/default_response_handler
dabla May 3, 2024
7dd1033
Merge branch 'main' into feature/default_response_handler
dabla May 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 18 additions & 23 deletions airflow/providers/microsoft/azure/hooks/msgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
from __future__ import annotations

import json
from contextlib import suppress
from http import HTTPStatus
from io import BytesIO
from typing import TYPE_CHECKING, Any, Callable
from json import JSONDecodeError
from typing import TYPE_CHECKING, Any
from urllib.parse import quote, urljoin, urlparse

import httpx
Expand Down Expand Up @@ -51,18 +53,17 @@
from airflow.models import Connection


class CallableResponseHandler(ResponseHandler):
"""
CallableResponseHandler executes the passed callable_function with response as parameter.

param callable_function: Function that is applied to the response.
"""
class DefaultResponseHandler(ResponseHandler):
"""DefaultResponseHandler returns JSON payload or content in bytes or response headers."""

def __init__(
self,
callable_function: Callable[[NativeResponseType, dict[str, ParsableFactory | None] | None], Any],
):
self.callable_function = callable_function
@staticmethod
def get_value(response: NativeResponseType) -> Any:
with suppress(JSONDecodeError):
return response.json()
content = response.content
if not content:
return {key: value for key, value in response.headers.items()}
return content

async def handle_response_async(
self, response: NativeResponseType, error_map: dict[str, ParsableFactory | None] | None = None
Expand All @@ -73,7 +74,7 @@ async def handle_response_async(
param response: The type of the native response object.
param error_map: The error dict to use in case of a failed request.
"""
value = self.callable_function(response, error_map)
value = self.get_value(response)
if response.status_code not in {200, 201, 202, 204, 302}:
message = value or response.reason_phrase
status_code = HTTPStatus(response.status_code)
Expand Down Expand Up @@ -269,20 +270,18 @@ async def run(
self,
url: str = "",
response_type: ResponseType | None = None,
response_handler: Callable[
[NativeResponseType, dict[str, ParsableFactory | None] | None], Any
] = lambda response, error_map: response.json(),
path_parameters: dict[str, Any] | None = None,
method: str = "GET",
query_parameters: dict[str, QueryParams] | None = None,
headers: dict[str, str] | None = None,
data: dict[str, Any] | str | BytesIO | None = None,
):
self.log.info("Executing url '%s' as '%s'", url, method)

response = await self.get_conn().send_primitive_async(
request_info=self.request_information(
url=url,
response_type=response_type,
response_handler=response_handler,
path_parameters=path_parameters,
method=method,
query_parameters=query_parameters,
Expand All @@ -293,17 +292,14 @@ async def run(
error_map=self.error_mapping(),
)

self.log.debug("response: %s", response)
self.log.info("response: %s", response)

return response

def request_information(
self,
url: str,
response_type: ResponseType | None = None,
response_handler: Callable[
[NativeResponseType, dict[str, ParsableFactory | None] | None], Any
] = lambda response, error_map: response.json(),
path_parameters: dict[str, Any] | None = None,
method: str = "GET",
query_parameters: dict[str, QueryParams] | None = None,
Expand All @@ -323,12 +319,11 @@ def request_information(
request_information.url_template = f"{{+baseurl}}/{self.normalize_url(url)}"
if not response_type:
request_information.request_options[ResponseHandlerOption.get_key()] = ResponseHandlerOption(
response_handler=CallableResponseHandler(response_handler)
response_handler=DefaultResponseHandler()
)
headers = {**self.DEFAULT_HEADERS, **headers} if headers else self.DEFAULT_HEADERS
for header_name, header_value in headers.items():
request_information.headers.try_add(header_name=header_name, header_value=header_value)
self.log.info("data: %s", data)
if isinstance(data, BytesIO) or isinstance(data, bytes) or isinstance(data, str):
request_information.content = data
elif data:
Expand Down
15 changes: 2 additions & 13 deletions airflow/providers/microsoft/azure/operators/msgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@

from kiota_abstractions.request_adapter import ResponseType
from kiota_abstractions.request_information import QueryParams
from kiota_abstractions.response_handler import NativeResponseType
from kiota_abstractions.serialization import ParsableFactory
from msgraph_core import APIVersion

from airflow.utils.context import Context
Expand All @@ -59,9 +57,6 @@ class MSGraphAsyncOperator(BaseOperator):
:param url: The url being executed on the Microsoft Graph API (templated).
:param response_type: The expected return type of the response as a string. Possible value are: `bytes`,
`str`, `int`, `float`, `bool` and `datetime` (default is None).
:param response_handler: Function to convert the native HTTPX response returned by the hook (default is
lambda response, error_map: response.json()). The default expression will convert the native response
to JSON. If response_type parameter is specified, then the response_handler will be ignored.
:param method: The HTTP method being used to do the REST call (default is GET).
:param conn_id: The HTTP Connection ID to run the operator against (templated).
:param key: The key that will be used to store `XCom's` ("return_value" is default).
Expand Down Expand Up @@ -94,9 +89,6 @@ def __init__(
*,
url: str,
response_type: ResponseType | None = None,
response_handler: Callable[
[NativeResponseType, dict[str, ParsableFactory | None] | None], Any
] = lambda response, error_map: response.json(),
path_parameters: dict[str, Any] | None = None,
url_template: str | None = None,
method: str = "GET",
Expand All @@ -116,7 +108,6 @@ def __init__(
super().__init__(**kwargs)
self.url = url
self.response_type = response_type
self.response_handler = response_handler
self.path_parameters = path_parameters
self.url_template = url_template
self.method = method
Expand All @@ -134,7 +125,6 @@ def __init__(
self.results: list[Any] | None = None

def execute(self, context: Context) -> None:
self.log.info("Executing url '%s' as '%s'", self.url, self.method)
self.defer(
trigger=MSGraphTrigger(
url=self.url,
Expand Down Expand Up @@ -167,14 +157,14 @@ def execute_complete(
self.log.debug("context: %s", context)

if event:
self.log.info("%s completed with %s: %s", self.task_id, event.get("status"), event)
self.log.debug("%s completed with %s: %s", self.task_id, event.get("status"), event)

if event.get("status") == "failure":
raise AirflowException(event.get("message"))

response = event.get("response")

self.log.info("response: %s", response)
self.log.debug("response: %s", response)

if response:
response = self.serializer.deserialize(response)
Expand Down Expand Up @@ -281,7 +271,6 @@ def trigger_next_link(self, response, method_name="execute_complete") -> None:
url=url,
query_parameters=query_parameters,
response_type=self.response_type,
response_handler=self.response_handler,
conn_id=self.conn_id,
timeout=self.timeout,
proxies=self.proxies,
Expand Down
118 changes: 66 additions & 52 deletions airflow/providers/microsoft/azure/sensors/msgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,43 +17,32 @@
# under the License.
from __future__ import annotations

import asyncio
import json
from typing import TYPE_CHECKING, Any, Callable, Sequence

from airflow.exceptions import AirflowException
from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook
from airflow.providers.microsoft.azure.triggers.msgraph import MSGraphTrigger, ResponseSerializer
from airflow.sensors.base import BaseSensorOperator, PokeReturnValue
from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.temporal import TimeDeltaTrigger

if TYPE_CHECKING:
from datetime import timedelta
from io import BytesIO

from kiota_abstractions.request_information import QueryParams
from kiota_abstractions.response_handler import NativeResponseType
from kiota_abstractions.serialization import ParsableFactory
from kiota_http.httpx_request_adapter import ResponseType
from msgraph_core import APIVersion

from airflow.triggers.base import TriggerEvent
from airflow.utils.context import Context


def default_event_processor(context: Context, event: TriggerEvent) -> bool:
if event.payload["status"] == "success":
return json.loads(event.payload["response"])["status"] == "Succeeded"
return False


class MSGraphSensor(BaseSensorOperator):
"""
A Microsoft Graph API sensor which allows you to poll an async REST call to the Microsoft Graph API.

:param url: The url being executed on the Microsoft Graph API (templated).
:param response_type: The expected return type of the response as a string. Possible value are: `bytes`,
`str`, `int`, `float`, `bool` and `datetime` (default is None).
:param response_handler: Function to convert the native HTTPX response returned by the hook (default is
lambda response, error_map: response.json()). The default expression will convert the native response
to JSON. If response_type parameter is specified, then the response_handler will be ignored.
:param method: The HTTP method being used to do the REST call (default is GET).
:param conn_id: The HTTP Connection ID to run the operator against (templated).
:param proxies: A dict defining the HTTP proxies to be used (default is None).
Expand Down Expand Up @@ -85,9 +74,6 @@ def __init__(
self,
url: str,
response_type: ResponseType | None = None,
response_handler: Callable[
[NativeResponseType, dict[str, ParsableFactory | None] | None], Any
] = lambda response, error_map: response.json(),
path_parameters: dict[str, Any] | None = None,
url_template: str | None = None,
method: str = "GET",
Expand All @@ -97,15 +83,15 @@ def __init__(
conn_id: str = KiotaRequestAdapterHook.default_conn_name,
proxies: dict | None = None,
api_version: APIVersion | None = None,
event_processor: Callable[[Context, TriggerEvent], bool] = default_event_processor,
event_processor: Callable[[Context, Any], bool] = lambda context, e: e.get("status") == "Succeeded",
result_processor: Callable[[Context, Any], Any] = lambda context, result: result,
serializer: type[ResponseSerializer] = ResponseSerializer,
retry_delay: timedelta | float = 60,
**kwargs,
):
super().__init__(**kwargs)
super().__init__(retry_delay=retry_delay, **kwargs)
self.url = url
self.response_type = response_type
self.response_handler = response_handler
self.path_parameters = path_parameters
self.url_template = url_template
self.method = method
Expand All @@ -119,45 +105,73 @@ def __init__(
self.result_processor = result_processor
self.serializer = serializer()

@property
def trigger(self):
return MSGraphTrigger(
url=self.url,
response_type=self.response_type,
response_handler=self.response_handler,
path_parameters=self.path_parameters,
url_template=self.url_template,
method=self.method,
query_parameters=self.query_parameters,
headers=self.headers,
data=self.data,
conn_id=self.conn_id,
timeout=self.timeout,
proxies=self.proxies,
api_version=self.api_version,
serializer=type(self.serializer),
def execute(self, context: Context):
self.defer(
trigger=MSGraphTrigger(
url=self.url,
response_type=self.response_type,
path_parameters=self.path_parameters,
url_template=self.url_template,
method=self.method,
query_parameters=self.query_parameters,
headers=self.headers,
data=self.data,
conn_id=self.conn_id,
timeout=self.timeout,
proxies=self.proxies,
api_version=self.api_version,
serializer=type(self.serializer),
),
method_name=self.execute_complete.__name__,
)

async def async_poke(self, context: Context) -> bool | PokeReturnValue:
self.log.info("Sensor triggered")
def retry_execute(
self,
context: Context,
) -> Any:
self.execute(context=context)

def execute_complete(
self,
context: Context,
event: dict[Any, Any] | None = None,
) -> Any:
"""
Execute callback when MSGraphSensor finishes execution.

This method gets executed automatically when MSGraphTrigger completes its execution.
"""
self.log.debug("context: %s", context)

if event:
self.log.debug("%s completed with %s: %s", self.task_id, event.get("status"), event)

if event.get("status") == "failure":
raise AirflowException(event.get("message"))

response = event.get("response")

self.log.debug("response: %s", response)

async for event in self.trigger.run():
self.log.debug("event: %s", event)
if response:
response = self.serializer.deserialize(response)

is_done = self.event_processor(context, event)
self.log.debug("deserialize response: %s", response)

self.log.debug("is_done: %s", is_done)
is_done = self.event_processor(context, response)

response = self.serializer.deserialize(event.payload["response"])
self.log.debug("is_done: %s", is_done)

self.log.debug("deserialize event: %s", response)
if is_done:
result = self.result_processor(context, response)

result = self.result_processor(context, response)
self.log.debug("processed response: %s", result)

self.log.debug("result: %s", result)
return result

return PokeReturnValue(is_done=is_done, xcom_value=result)
return PokeReturnValue(is_done=True)
self.defer(
trigger=TimeDeltaTrigger(self.retry_delay),
method_name=self.retry_execute.__name__,
)

def poke(self, context) -> bool | PokeReturnValue:
return asyncio.run(self.async_poke(context))
return None
Loading