Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

import httpx
from azure.identity import ClientSecretCredential
from httpx import Timeout
from httpx import AsyncHTTPTransport, Timeout
from kiota_abstractions.api_error import APIError
from kiota_abstractions.method import Method
from kiota_abstractions.request_information import RequestInformation
Expand Down Expand Up @@ -208,9 +208,9 @@ def format_no_proxy_url(url: str) -> str:
def to_httpx_proxies(cls, proxies: dict) -> dict:
proxies = proxies.copy()
if proxies.get("http"):
proxies["http://"] = proxies.pop("http")
proxies["http://"] = AsyncHTTPTransport(proxy=proxies.pop("http"))
if proxies.get("https"):
proxies["https://"] = proxies.pop("https")
proxies["https://"] = AsyncHTTPTransport(proxy=proxies.pop("https"))
if proxies.get("no"):
for url in proxies.pop("no", "").split(","):
proxies[cls.format_no_proxy_url(url.strip())] = None
Expand Down Expand Up @@ -288,7 +288,7 @@ def get_conn(self) -> RequestAdapter:
http_client = GraphClientFactory.create_with_default_middleware(
api_version=api_version, # type: ignore
client=httpx.AsyncClient(
proxy=httpx_proxies, # type: ignore
mounts=httpx_proxies,
timeout=Timeout(timeout=self.timeout),
verify=verify,
trust_env=trust_env,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def execute(self, context: Context):
def retry_execute(
self,
context: Context,
**kwargs,
) -> Any:
self.execute(context=context)

Expand Down
2 changes: 1 addition & 1 deletion providers/tests/microsoft/azure/resources/status.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"id": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef", "createdDateTime": "2024-04-10T15:05:17.357", "status": "Succeeded"}
[{"id": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef", "createdDateTime": "2024-04-10T15:05:17.357", "status": "InProgress"},{"id": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef", "createdDateTime": "2024-04-10T15:05:17.357", "status": "Succeeded"}]
27 changes: 21 additions & 6 deletions providers/tests/microsoft/azure/sensors/test_msgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import json
from datetime import datetime

import pytest

Expand All @@ -31,7 +32,7 @@
class TestMSGraphSensor(Base):
def test_execute(self):
status = load_json("resources", "status.json")
response = mock_json_response(200, status)
response = mock_json_response(200, *status)

with self.patch_hook_and_request_adapter(response):
sensor = MSGraphSensor(
Expand All @@ -40,6 +41,7 @@ def test_execute(self):
url="myorg/admin/workspaces/scanStatus/{scanId}",
path_parameters={"scanId": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"},
result_processor=lambda context, result: result["id"],
retry_delay=5,
timeout=350.0,
)

Expand All @@ -48,16 +50,22 @@ def test_execute(self):
assert sensor.path_parameters == {"scanId": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"}
assert isinstance(results, str)
assert results == "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"
assert len(events) == 1
assert len(events) == 3
assert isinstance(events[0], TriggerEvent)
assert events[0].payload["status"] == "success"
assert events[0].payload["type"] == "builtins.dict"
assert events[0].payload["response"] == json.dumps(status)
assert events[0].payload["response"] == json.dumps(status[0])
assert isinstance(events[1], TriggerEvent)
assert isinstance(events[1].payload, datetime)
assert isinstance(events[2], TriggerEvent)
assert events[2].payload["status"] == "success"
assert events[2].payload["type"] == "builtins.dict"
assert events[2].payload["response"] == json.dumps(status[1])

@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Lambda parameters works in Airflow >= 2.10.0")
def test_execute_with_lambda_parameter(self):
status = load_json("resources", "status.json")
response = mock_json_response(200, status)
response = mock_json_response(200, *status)

with self.patch_hook_and_request_adapter(response):
sensor = MSGraphSensor(
Expand All @@ -66,6 +74,7 @@ def test_execute_with_lambda_parameter(self):
url="myorg/admin/workspaces/scanStatus/{scanId}",
path_parameters=lambda context, jinja_env: {"scanId": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"},
result_processor=lambda context, result: result["id"],
retry_delay=5,
timeout=350.0,
)

Expand All @@ -74,11 +83,17 @@ def test_execute_with_lambda_parameter(self):
assert sensor.path_parameters == {"scanId": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"}
assert isinstance(results, str)
assert results == "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"
assert len(events) == 1
assert len(events) == 3
assert isinstance(events[0], TriggerEvent)
assert events[0].payload["status"] == "success"
assert events[0].payload["type"] == "builtins.dict"
assert events[0].payload["response"] == json.dumps(status)
assert events[0].payload["response"] == json.dumps(status[0])
assert isinstance(events[1], TriggerEvent)
assert isinstance(events[1].payload, datetime)
assert isinstance(events[2], TriggerEvent)
assert events[2].payload["status"] == "success"
assert events[2].payload["type"] == "builtins.dict"
assert events[2].payload["response"] == json.dumps(status[1])

def test_template_fields(self):
sensor = MSGraphSensor(
Expand Down
6 changes: 4 additions & 2 deletions providers/tests/microsoft/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,10 @@ def xcom_pull(
run_id: str | None = None,
) -> Any:
if map_indexes:
return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}")
return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}")
return values.get(
f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}", default
)
return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}", default)

def xcom_push(self, key: str, value: Any, session: Session = NEW_SESSION, **kwargs) -> None:
values[f"{self.task_id}_{self.dag_id}_{key}_{self.map_index}"] = value
Expand Down