Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import asyncio
import base64
import time
import uuid
Expand Down Expand Up @@ -78,17 +79,36 @@ class SnowflakeSqlApiHook(SnowflakeHook):
:param token_renewal_delta: Renewal time of the JWT Token in timedelta
:param deferrable: Run operator in the deferrable mode.
:param api_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` & ``tenacity.AsyncRetrying`` classes.
:param http_request_kwargs: Optional keyword arguments forwarded to ``requests.Session.request`` for synchronous HTTP calls.
Request-defining fields (e.g. ``method``, ``url``, ``headers``, ``params``, ``json``)
are owned by the hook and must not be provided here.

:param aiohttp_session_kwargs: Optional keyword arguments forwarded to
``aiohttp.ClientSession`` for asynchronous HTTP calls.
Session-owned fields like ``headers`` are managed by the hook
and must not be overridden here.

:param aiohttp_request_kwargs: Optional keyword arguments forwarded to
``aiohttp.ClientSession.request`` for asynchronous HTTP calls.
Request identity fields (e.g. ``method``, ``url``, ``headers``,
``params``) are owned by the hook and must not be overridden here.
"""

LIFETIME = timedelta(minutes=59) # The tokens will have a 59 minute lifetime
RENEWAL_DELTA = timedelta(minutes=54) # Tokens will be renewed after 54 minutes
HTTP_REQUEST_KWARGS_GUARD_KEYS: set[str] = {"method", "url", "headers", "params", "json"}
AIOHTTP_SESSION_KWARGS_GUARD_KEYS: set[str] = {"headers"}
AIOHTTP_REQUEST_KWARGS_GUARD_KEYS: set[str] = {"method", "url", "params", "headers"}

def __init__(
self,
snowflake_conn_id: str,
token_life_time: timedelta = LIFETIME,
token_renewal_delta: timedelta = RENEWAL_DELTA,
api_retry_args: dict[Any, Any] | None = None, # Optional retry arguments passed to tenacity.retry
http_request_kwargs: dict[str, Any] | None = None,
aiohttp_session_kwargs: dict[str, Any] | None = None,
aiohttp_request_kwargs: dict[str, Any] | None = None,
*args: Any,
**kwargs: Any,
):
Expand All @@ -109,6 +129,10 @@ def __init__(
if api_retry_args:
self.retry_config.update(api_retry_args)

self.http_request_kwargs = http_request_kwargs or {}
self.aiohttp_session_kwargs = aiohttp_session_kwargs or {}
self.aiohttp_request_kwargs = aiohttp_request_kwargs or {}

def get_private_key(self) -> None:
"""Get the private key from snowflake connection."""
conn = self.get_connection(self.snowflake_conn_id)
Expand Down Expand Up @@ -447,7 +471,7 @@ def _should_retry_on_error(exception) -> bool:
return exception.status in [429, 503, 504]
if isinstance(
exception,
ConnectionError | Timeout | ClientConnectionError,
ConnectionError | Timeout | ClientConnectionError | asyncio.TimeoutError,
):
return True
return False
Expand All @@ -468,15 +492,30 @@ def _make_api_call_with_retries(
:param json: (Optional) The data to include in the API call.
:return: The response object from the API call.
"""
if method.upper() not in ("GET", "POST"):
raise ValueError(f"Unsupported HTTP method: {method}")

user_kwargs: dict[str, Any] = dict(self.http_request_kwargs or {})
forbidden: set[str] = self.HTTP_REQUEST_KWARGS_GUARD_KEYS & set(user_kwargs)
if forbidden:
raise ValueError(
f"http_request_kwargs must not override request identity fields: {sorted(forbidden)}"
)

with requests.Session() as session:
for attempt in Retrying(**self.retry_config): # type: ignore
with attempt:
if method.upper() in ("GET", "POST"):
response = session.request(
method=method.lower(), url=url, headers=headers, params=params, json=json
)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
base_request_kwargs: dict[str, Any] = {
"method": method.lower(),
"url": url,
"headers": headers,
"params": params,
"json": json,
}
# Order is important
# user first, base second => base wins even if guard misses something
request_kwargs: dict[str, Any] = {**user_kwargs, **base_request_kwargs}
response = session.request(**request_kwargs)
response.raise_for_status()
return response.status_code, response.json()

Expand All @@ -493,14 +532,35 @@ async def _make_api_call_with_retries_async(self, method, url, headers, params=N
:param params: (Optional) The query parameters to include in the API call.
:return: The response object from the API call.
"""
async with aiohttp.ClientSession(headers=headers) as session:
if method.upper() != "GET":
raise ValueError(f"Unsupported HTTP method: {method}")

user_session_kwargs: dict[str, Any] = dict(self.aiohttp_session_kwargs or {})
forbidden = self.AIOHTTP_SESSION_KWARGS_GUARD_KEYS & set(user_session_kwargs)
if forbidden:
raise ValueError(
f"aiohttp_session_kwargs must not override session-owned fields: {sorted(forbidden)}"
)

user_request_kwargs: dict[str, Any] = dict(self.aiohttp_request_kwargs or {})
forbidden = self.AIOHTTP_REQUEST_KWARGS_GUARD_KEYS & set(user_request_kwargs)
if forbidden:
raise ValueError(
f"aiohttp_request_kwargs must not override request identity fields: {sorted(forbidden)}"
)
base_session_kwargs: dict[str, Any] = {"headers": headers}
session_kwargs: dict[str, Any] = {**user_session_kwargs, **base_session_kwargs}
async with aiohttp.ClientSession(**session_kwargs) as session:
async for attempt in AsyncRetrying(**self.retry_config):
with attempt:
if method.upper() == "GET":
async with session.request(method=method.lower(), url=url, params=params) as response:
response.raise_for_status()
# Return status and json content for async processing
content = await response.json()
return response.status, content
else:
raise ValueError(f"Unsupported HTTP method: {method}")
base_request_kwargs: dict[str, Any] = {
"method": method.lower(),
"url": url,
"params": params,
}
request_kwargs: dict[str, Any] = {**user_request_kwargs, **base_request_kwargs}
async with session.request(**request_kwargs) as response:
response.raise_for_status()
# Return status and json content for async processing
content = await response.json()
return response.status, content
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import asyncio
import base64
import unittest
import uuid
Expand Down Expand Up @@ -1471,3 +1472,142 @@ def test_cancel_queries(self, mock_cancel_execution):

assert mock_cancel_execution.call_count == 3
mock_cancel_execution.assert_has_calls([call("query-1"), call("query-2"), call("query-3")])

def test_make_api_call_passes_timeout_to_requests(self, mock_requests):
"""Test that http_request_kwargs are forwarded to requests.request()."""
hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn", http_request_kwargs={"timeout": 12.0})
resp = mock.MagicMock()
resp.status_code = 200
resp.raise_for_status.return_value = None
resp.json.return_value = {"ok": True}
mock_requests.request.return_value = resp

hook._make_api_call_with_retries("GET", API_URL, HEADERS)

mock_requests.request.assert_called_once_with(
method="get",
url=API_URL,
headers=HEADERS,
params=None,
json=None,
timeout=12.0,
)

@pytest.mark.parametrize("forbidden_key", ["method", "url", "headers", "params", "json"])
def test_make_api_call_with_retries_rejects_http_request_kwargs_overriding_identity_fields(
self,
forbidden_key: str,
):
"""
Test http_request_kwargs cannot override request identity fields.
The hook owns request-defining fields such as method, url, headers, params, and json.
Supplying any of these via http_request_kwargs must fail fast with a ValueError.
"""
hook = SnowflakeSqlApiHook(
snowflake_conn_id="test_conn",
http_request_kwargs={forbidden_key: "boom"},
)

with pytest.raises(
ValueError,
match=r"http_request_kwargs must not override request identity fields",
):
hook._make_api_call_with_retries("GET", API_URL, HEADERS)

@pytest.mark.asyncio
async def test_make_api_call_with_retries_async_passes_timeout_to_clientsession(self):
"""
Test that aiohttp_session_kwargs are forwarded to aiohttp.ClientSession.
"""
hook = SnowflakeSqlApiHook(
snowflake_conn_id="test_conn",
aiohttp_session_kwargs={"timeout": aiohttp.ClientTimeout(total=7.0)},
)

with mock.patch(f"{MODULE_PATH}.aiohttp.ClientSession") as client_session_cls:
session_cm = mock.MagicMock()
client_session_cls.return_value.__aenter__ = AsyncMock(return_value=session_cm)

req_cm = mock.MagicMock()
session_cm.request.return_value = req_cm

resp = mock.MagicMock()
resp.status = 200
resp.raise_for_status.return_value = None
resp.json = AsyncMock(return_value=GET_RESPONSE)
req_cm.__aenter__ = AsyncMock(return_value=resp)

await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS)

_, kwargs = client_session_cls.call_args
timeout_obj = kwargs["timeout"]
assert isinstance(timeout_obj, aiohttp.ClientTimeout)
assert timeout_obj.total == 7.0

@pytest.mark.asyncio
async def test_make_api_call_with_retries_async_retries_on_timeout_error(self, mock_async_request):
"""
Test that the async API call is retried when a timeout error occurs.

The first request raises asyncio.TimeoutError, and the second attempt succeeds.
This ensures retry behavior is correctly applied to transient async failures.
"""
hook = SnowflakeSqlApiHook(
snowflake_conn_id="test_conn",
aiohttp_session_kwargs={"timeout": aiohttp.ClientTimeout(total=7.0)},
)

mock_async_request.__aenter__.side_effect = [
asyncio.TimeoutError(),
create_async_request_client_response_success(json=GET_RESPONSE, status_code=200),
]

status, data = await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS)

assert status == 200
assert data == GET_RESPONSE
assert mock_async_request.__aenter__.call_count == 2

@pytest.mark.asyncio
@pytest.mark.parametrize("forbidden_key", ["headers"])
async def test_make_api_call_with_retries_async_rejects_aiohttp_session_kwargs_overriding_session_owned_fields(
self, forbidden_key
):
"""
Test aiohttp_session_kwargs cannot override session-owned fields.
Session-owned fields such as headers are managed by the hook and
must not be overridden via aiohttp_session_kwargs.
"""
hook = SnowflakeSqlApiHook(
snowflake_conn_id="test_conn",
aiohttp_session_kwargs={forbidden_key: {"x": "boom"}},
)

with pytest.raises(
ValueError,
match=r"aiohttp_session_kwargs must not override session-owned fields",
):
await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS)

@pytest.mark.asyncio
@pytest.mark.parametrize("forbidden_key", ["method", "url", "params", "headers"])
async def test_make_api_call_with_retries_async_rejects_aiohttp_request_kwargs_overriding_identity_fields(
self, forbidden_key
):
"""
Test aiohttp_request_kwargs cannot override request identity fields.

Request identity fields such as method, url, params, and headers
are owned by the hook and must not be overridden via
aiohttp_request_kwargs.
"""
hook = SnowflakeSqlApiHook(
snowflake_conn_id="test_conn",
aiohttp_request_kwargs={forbidden_key: "boom"},
)

with pytest.raises(
ValueError,
match=r"aiohttp_request_kwargs must not override request identity fields",
):
await hook._make_api_call_with_retries_async("GET", API_URL, HEADERS)