Skip to content

Commit e8fc36c

Browse files
[async] test fixes:
Fixed synch factory in async code of _wif.py Fixed async make_session and added necessary host_port_pooling
1 parent 7581edc commit e8fc36c

File tree

5 files changed

+78
-11
lines changed

5 files changed

+78
-11
lines changed

src/snowflake/connector/aio/_session_manager.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -540,17 +540,24 @@ def make_session(self, *, url: str | None = None) -> aiohttp.ClientSession:
540540
session_manager=self.clone(),
541541
snowflake_ocsp_mode=self._cfg.snowflake_ocsp_mode,
542542
)
543-
# We use requests.utils here (in asynch code) to keep the behaviour uniform for synch and asynch code. If we wanted each version to depict its http library's behaviour, we could use here: aiohttp.helpers.proxy_bypass(url, proxies={...}) here
544-
proxy = (
545-
None
546-
if should_bypass_proxies(url, no_proxy=self.config.no_proxy)
547-
else self.proxy_url
548-
)
543+
544+
proxy_from_conn_params: str | None = None
545+
if not aiohttp.helpers.proxies_from_env():
546+
# TODO: This is only needed because we want to keep compatibility with the synch driver version.
547+
# Otherwise, we could remove that condition and always pass proxy from conn params to the Session constructor.
548+
# But in such case precedence will be reverted and it will overwrite the env vars settings.
549+
550+
# We use requests.utils here (in asynch code) to keep the behaviour uniform for synch and asynch code. If we wanted each version to depict its http library's behaviour, we could use here: aiohttp.helpers.proxy_bypass(url, proxies={...}) here
551+
proxy_from_conn_params = (
552+
None
553+
if should_bypass_proxies(url, no_proxy=self.config.no_proxy)
554+
else self.proxy_url
555+
)
549556
# Construct session with base proxy set, request() may override per-URL when bypassing
550557
return self.SessionWithProxy(
551558
connector=connector,
552559
trust_env=self._cfg.trust_env,
553-
proxy=proxy,
560+
proxy=proxy_from_conn_params,
554561
)
555562

556563

src/snowflake/connector/aio/_wif_util.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
from ..errorcode import ER_WIF_CREDENTIALS_NOT_FOUND
1616
from ..errors import MissingDependencyError, ProgrammingError
17-
from ..session_manager import SessionManagerFactory
1817
from ..wif_util import (
1918
DEFAULT_ENTRA_SNOWFLAKE_RESOURCE,
2019
SNOWFLAKE_AUDIENCE,
@@ -24,7 +23,7 @@
2423
extract_iss_and_sub_without_signature_verification,
2524
get_aws_sts_hostname,
2625
)
27-
from ._session_manager import SessionManager
26+
from ._session_manager import SessionManager, SessionManagerFactory
2827

2928
logger = logging.getLogger(__name__)
3029

test/unit/aio/mock_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ async def forbidden_connect(*args, **kwargs):
3535

3636
class MockSessionManager(SessionManager):
3737
def make_session(self, *, url: str | None = None):
38-
session = super().make_session(url)
38+
session = super().make_session(url=url)
3939
if not allow_send:
4040
# Block at connector._connect level (like sync blocks session.send)
4141
# This allows patches on session.request to work

test/unit/aio/test_connection_async_unit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -912,7 +912,7 @@ def test_connect_metadata_preservation():
912912
connect_doc == source_doc
913913
), "inspect.getdoc(connect) should match inspect.getdoc(SnowflakeConnection.__init__)"
914914

915-
# Test 8: Check that connect is callable and returns expected type
915+
# Test 8: Check that connect is callable
916916
assert callable(connect), "connect should be callable"
917917

918918
# Test 9: Check type() and __class__ values (important for user introspection)
@@ -933,3 +933,4 @@ def test_connect_metadata_preservation():
933933
assert (
934934
len(params) > 0
935935
), "connect should have parameters from SnowflakeConnection.__init__"
936+
# Should have parameters like account, user, password, etc.

test/unit/aio/test_proxies_async.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import urllib.request
34
from collections import deque
45
from test.unit.test_proxies import (
56
DbRequestFlags,
@@ -15,6 +16,9 @@
1516

1617
import aiohttp
1718
import pytest
19+
from aiohttp import BasicAuth
20+
from aiohttp.helpers import proxies_from_env
21+
from yarl import URL
1822

1923
from snowflake.connector.aio import connect as async_connect
2024

@@ -355,6 +359,57 @@ async def test_no_proxy_basic_param_proxy_bypass_backend(
355359
assert flags.proxy_saw_storage is True
356360

357361

362+
@pytest.fixture
363+
def fix_aiohttp_proxy_bypass(monkeypatch):
364+
"""Fix aiohttp's proxy bypass to check host:port instead of just host.
365+
366+
This fixture implements a two-step fix:
367+
1. Override get_env_proxy_for_url to use host_port_subcomponent for proxy_bypass
368+
2. Override urllib.request._splitport to return (host:port, port) for proper matching
369+
"""
370+
371+
# Step 1: Override get_env_proxy_for_url to pass host:port to proxy_bypass
372+
def get_env_proxy_for_url_with_port(url: URL) -> tuple[URL, BasicAuth | None]:
373+
"""Get a permitted proxy for the given URL from the env, checking host:port."""
374+
from urllib.request import proxy_bypass
375+
376+
# Check proxy bypass using host:port combination
377+
if url.host is not None:
378+
# Use host_port_subcomponent which includes port
379+
host_port = f"{url.host}:{url.port}" if url.port else url.host
380+
if proxy_bypass(host_port):
381+
raise LookupError(f"Proxying is disallowed for `{host_port!r}`")
382+
383+
proxies_in_env = proxies_from_env()
384+
try:
385+
proxy_info = proxies_in_env[url.scheme]
386+
except KeyError:
387+
raise LookupError(f"No proxies found for `{url!s}` in the env")
388+
else:
389+
return proxy_info.proxy, proxy_info.proxy_auth
390+
391+
# Step 2: Override _splitport to return host:port as first element
392+
original_splitport = urllib.request._splitport
393+
394+
def _splitport_with_port(host):
395+
"""Override to return (host:port, port) instead of (host, port)."""
396+
result = original_splitport(host)
397+
if result is None:
398+
return (host, None)
399+
host_only, port = result
400+
# If port was found, return the original host (with port) as first element
401+
if port is not None:
402+
return (host, port) # Return original host:port string
403+
return (host_only, port)
404+
405+
monkeypatch.setattr(
406+
aiohttp.client, "get_env_proxy_for_url", get_env_proxy_for_url_with_port
407+
)
408+
monkeypatch.setattr(urllib.request, "_splitport", _splitport_with_port)
409+
410+
yield
411+
412+
358413
@pytest.mark.skipolddriver
359414
@pytest.mark.parametrize("proxy_method", ["explicit_args", "env_vars"])
360415
@pytest.mark.parametrize("no_proxy_source", ["param", "env"])
@@ -366,6 +421,7 @@ async def test_no_proxy_source_vs_proxy_method_matrix(
366421
proxy_method,
367422
no_proxy_source,
368423
host_port_pooling,
424+
fix_aiohttp_proxy_bypass,
369425
):
370426
if proxy_method == "env_vars" and no_proxy_source == "param":
371427
pytest.xfail(
@@ -407,6 +463,7 @@ async def test_no_proxy_backend_matrix(
407463
proxy_method,
408464
no_proxy_source,
409465
host_port_pooling,
466+
fix_aiohttp_proxy_bypass,
410467
):
411468
if proxy_method == "env_vars" and no_proxy_source == "param":
412469
pytest.xfail(
@@ -461,6 +518,7 @@ async def test_no_proxy_multiple_values_param_only(
461518
wiremock_mapping_dir,
462519
proxy_env_vars,
463520
no_proxy_factory,
521+
host_port_pooling, # Unlike in synch code - Session stores no_proxy setup so it would be reused for proxy and backend since they are both on localhost
464522
):
465523
target_wm, storage_wm, proxy_wm = wiremock_backend_storage_proxy
466524
_setup_backend_storage_mappings(
@@ -569,6 +627,8 @@ async def test_proxy_env_vars_take_precedence_over_connection_params(
569627
wiremock_generic_mappings_dir,
570628
proxy_env_vars,
571629
monkeypatch,
630+
host_port_pooling,
631+
fix_aiohttp_proxy_bypass,
572632
):
573633
"""Verify that proxy_host/proxy_port connection parameters take precedence over env vars.
574634

0 commit comments

Comments
 (0)