Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 31 additions & 25 deletions airflow/providers/microsoft/azure/operators/msgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __init__(
timeout: float | None = None,
proxies: dict | None = None,
api_version: APIVersion | str | None = None,
pagination_function: Callable[[MSGraphAsyncOperator, dict], tuple[str, dict]] | None = None,
pagination_function: Callable[[MSGraphAsyncOperator, dict, Context], tuple[str, dict]] | None = None,
result_processor: Callable[[Context, Any], Any] = lambda context, result: result,
serializer: type[ResponseSerializer] = ResponseSerializer,
**kwargs: Any,
Expand All @@ -122,7 +122,6 @@ def __init__(
self.pagination_function = pagination_function or self.paginate
self.result_processor = result_processor
self.serializer: ResponseSerializer = serializer()
self.results: list[Any] | None = None

def execute(self, context: Context) -> None:
self.defer(
Expand Down Expand Up @@ -166,6 +165,8 @@ def execute_complete(

self.log.debug("response: %s", response)

results = self.pull_xcom(context=context)

if response:
response = self.serializer.deserialize(response)

Expand All @@ -178,39 +179,46 @@ def execute_complete(
event["response"] = result

try:
self.trigger_next_link(response=response, method_name=self.execute_complete.__name__)
self.trigger_next_link(
response=response, method_name=self.execute_complete.__name__, context=context
)
except TaskDeferred as exception:
self.results = self.pull_xcom(context=context)
self.append_result(
results=results,
result=result,
append_result_as_list_if_absent=True,
)
self.push_xcom(context=context, value=self.results)
self.push_xcom(context=context, value=results)
raise exception

self.append_result(result=result)
if not results:
return result

return self.results
self.append_result(results=results, result=result)
return results
return None

@classmethod
def append_result(
self,
cls,
results: list[Any],
result: Any,
append_result_as_list_if_absent: bool = False,
):
if isinstance(self.results, list):
) -> list[Any]:
if isinstance(results, list):
if isinstance(result, list):
self.results.extend(result)
results.extend(result)
else:
self.results.append(result)
results.append(result)
else:
if append_result_as_list_if_absent:
if isinstance(result, list):
self.results = result
return result
else:
self.results = [result]
return [result]
else:
self.results = result
return result
return results

def pull_xcom(self, context: Context) -> list:
map_index = context["ti"].map_index
Expand Down Expand Up @@ -251,27 +259,25 @@ def push_xcom(self, context: Context, value) -> None:
self.xcom_push(context=context, key=self.key, value=value)

@staticmethod
def paginate(operator: MSGraphAsyncOperator, response: dict) -> tuple[Any, dict[str, Any] | None]:
def paginate(
operator: MSGraphAsyncOperator, response: dict, context: Context
) -> tuple[Any, dict[str, Any] | None]:
odata_count = response.get("@odata.count")
if odata_count and operator.query_parameters:
query_parameters = deepcopy(operator.query_parameters)
top = query_parameters.get("$top")
odata_count = response.get("@odata.count")

if top and odata_count:
if len(response.get("value", [])) == top:
skip = (
sum(map(lambda result: len(result["value"]), operator.results)) + top
if operator.results
else top
)
if len(response.get("value", [])) == top and context:
results = operator.pull_xcom(context=context)
skip = sum(map(lambda result: len(result["value"]), results)) + top if results else top
query_parameters["$skip"] = skip
return operator.url, query_parameters
return response.get("@odata.nextLink"), operator.query_parameters

def trigger_next_link(self, response, method_name="execute_complete") -> None:
def trigger_next_link(self, response, method_name: str, context: Context) -> None:
if isinstance(response, dict):
url, query_parameters = self.pagination_function(self, response)
url, query_parameters = self.pagination_function(self, response, context)

self.log.debug("url: %s", url)
self.log.debug("query_parameters: %s", query_parameters)
Expand Down
36 changes: 35 additions & 1 deletion tests/providers/microsoft/azure/operators/test_msgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@
from airflow.providers.microsoft.azure.operators.msgraph import MSGraphAsyncOperator
from airflow.triggers.base import TriggerEvent
from tests.providers.microsoft.azure.base import Base
from tests.providers.microsoft.conftest import load_file, load_json, mock_json_response, mock_response
from tests.providers.microsoft.conftest import (
load_file,
load_json,
mock_context,
mock_json_response,
mock_response,
)


class TestMSGraphAsyncOperator(Base):
Expand Down Expand Up @@ -127,3 +133,31 @@ def test_template_fields(self):

for template_field in MSGraphAsyncOperator.template_fields:
getattr(operator, template_field)

def test_paginate_without_query_parameters(self):
    """Check the default paginator when the operator has no query parameters.

    The expected outcome is the ``@odata.nextLink`` taken from the response,
    paired with ``None`` as the query parameters.
    """
    # No ``query_parameters`` argument is passed to the operator here.
    task = MSGraphAsyncOperator(
        task_id="user_license_details",
        conn_id="msgraph_api",
        url="users",
    )
    ctx = mock_context(task=task)
    payload = load_json("resources", "users.json")

    paged = MSGraphAsyncOperator.paginate(task, payload, ctx)
    next_link, query_parameters = paged

    assert next_link == payload["@odata.nextLink"]
    assert query_parameters is None

def test_paginate_with_context_query_parameters(self):
    """Check the default paginator when ``$top`` and ``@odata.count`` are set.

    The expected outcome is the operator's own URL, with the original
    ``$top`` kept and a ``$skip`` of 12 added to the query parameters.
    """
    page_size = {"$top": 12}
    task = MSGraphAsyncOperator(
        task_id="user_license_details",
        conn_id="msgraph_api",
        url="users",
        query_parameters=page_size,
    )
    ctx = mock_context(task=task)
    payload = load_json("resources", "users.json")
    # The count is injected by the test rather than read from the fixture.
    payload["@odata.count"] = 100

    paged = MSGraphAsyncOperator.paginate(task, payload, ctx)
    url, query_parameters = paged

    assert url == "users"
    assert query_parameters == {"$skip": 12, "$top": 12}