Skip to content

Commit 7db0542

Browse files
fixup! fixup! Make the changes backward compatible
1 parent 35b6972 commit 7db0542

File tree

4 files changed

+55
-16
lines changed

4 files changed

+55
-16
lines changed

src/snowflake/connector/auth/oauth_code.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import hashlib
99
import json
1010
import logging
11-
import os
1211
import secrets
1312
import socket
1413
import time
@@ -17,11 +16,7 @@
1716
from typing import TYPE_CHECKING, Any
1817

1918
from ..compat import parse_qs, urlparse, urlsplit
20-
from ..constants import (
21-
ENV_VAR_OAUTH_SOCKET_ADDRESS,
22-
ENV_VAR_OAUTH_SOCKET_PORT,
23-
OAUTH_TYPE_AUTHORIZATION_CODE,
24-
)
19+
from ..constants import OAUTH_TYPE_AUTHORIZATION_CODE
2520
from ..errorcode import (
2621
ER_INVALID_VALUE,
2722
ER_OAUTH_CALLBACK_ERROR,
@@ -70,6 +65,7 @@ def __init__(
7065
external_browser_timeout: int | None = None,
7166
enable_single_use_refresh_tokens: bool = False,
7267
connection: SnowflakeConnection | None = None,
68+
uri: str | None = None,
7369
**kwargs,
7470
) -> None:
7571
authentication_url, redirect_uri = self._validate_oauth_code_uris(
@@ -97,6 +93,7 @@ def __init__(
9793
self._origin: str | None = None
9894
self._authentication_url = authentication_url
9995
self._redirect_uri = redirect_uri
96+
self._uri = uri
10097
self._state = secrets.token_urlsafe(43)
10198
logger.debug("chose oauth state: %s", "".join("*" for _ in self._state))
10299
self._protocol = "http"
@@ -124,18 +121,11 @@ def _request_tokens(
124121
logger.debug("authenticating with OAuth authorization code flow")
125122
with AuthHttpServer(
126123
redirect_uri=self._redirect_uri,
127-
uri=self._read_uri_from_env(),
124+
uri=self._uri or self._redirect_uri, # To preserve backward compatibility
128125
) as callback_server:
129126
code = self._do_authorization_request(callback_server, conn)
130127
return self._do_token_request(code, callback_server, conn)
131128

132-
def _read_uri_from_env(self) -> str:
133-
oauth_socket_address = os.getenv(
134-
ENV_VAR_OAUTH_SOCKET_ADDRESS, "http://localhost"
135-
)
136-
oauth_socket_port = os.getenv(ENV_VAR_OAUTH_SOCKET_PORT, "0")
137-
return f"{oauth_socket_address}:{oauth_socket_port}"
138-
139129
def _check_post_requested(
140130
self, data: list[str]
141131
) -> tuple[str, str] | tuple[None, None]:

src/snowflake/connector/connection.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,11 @@ def _get_private_bytes_from_file(
384384
# SNOW-1825621: OAUTH implementation
385385
),
386386
"oauth_redirect_uri": ("http://127.0.0.1", str),
387+
"oauth_socket_uri": (
388+
"http://127.0.0.1",
389+
str,
390+
# SNOW-2194055: Separate server and redirect URIs in AuthHttpServer
391+
),
387392
"oauth_scope": (
388393
"",
389394
str,
@@ -1456,6 +1461,7 @@ def __open_connection(self):
14561461
host=self.host, port=self.port
14571462
),
14581463
redirect_uri=self._oauth_redirect_uri,
1464+
uri=self._oauth_socket_uri,
14591465
scope=self._oauth_scope,
14601466
pkce_enabled=not self._oauth_disable_pkce,
14611467
token_cache=(

src/snowflake/connector/constants.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -441,8 +441,6 @@ class IterUnit(Enum):
441441
# TODO: all env variables definitions should be here
442442
ENV_VAR_PARTNER = "SF_PARTNER"
443443
ENV_VAR_TEST_MODE = "SNOWFLAKE_TEST_MODE"
444-
ENV_VAR_OAUTH_SOCKET_ADDRESS = "SNOWFLAKE_OAUTH_SOCKET_ADDRESS"
445-
ENV_VAR_OAUTH_SOCKET_PORT = "SNOWFLAKE_OAUTH_SOCKET_PORT"
446444

447445
_DOMAIN_NAME_MAP = {_DEFAULT_HOSTNAME_TLD: "GLOBAL", _CHINA_HOSTNAME_TLD: "CHINA"}
448446

test/unit/test_auth_oauth_auth_code.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,3 +394,48 @@ def mock_request_tokens(self, **kwargs):
394394
assert isinstance(conn.auth_class, AuthByOauthCode)
395395

396396
conn.close()
397+
398+
399+
@pytest.mark.parametrize(
400+
"uri,redirect_uri",
401+
[
402+
("https://example.com/server", "http://localhost:8080"),
403+
("http://localhost:8080", "https://example.com/redirect"),
404+
("http://127.0.0.1:9090", "https://server.com/oauth/callback"),
405+
(None, "https://redirect.example.com"),
406+
],
407+
)
408+
@mock.patch(
409+
"snowflake.connector.auth.oauth_code.AuthByOauthCode._do_authorization_request"
410+
)
411+
@mock.patch("snowflake.connector.auth.oauth_code.AuthByOauthCode._do_token_request")
412+
def test_auth_oauth_auth_code_passes_uri_to_http_server(
413+
_, __, uri, redirect_uri, omit_oauth_urls_check
414+
):
415+
"""Test that uri and redirect_uri parameters are passed correctly to AuthHttpServer."""
416+
auth = AuthByOauthCode(
417+
"app",
418+
"clientId",
419+
"clientSecret",
420+
"https://auth_url",
421+
"tokenRequestUrl",
422+
redirect_uri,
423+
"scope",
424+
"host",
425+
uri=uri,
426+
)
427+
428+
with patch(
429+
"snowflake.connector.auth.oauth_code.AuthHttpServer",
430+
# return_value=None,
431+
) as mock_http_server_init:
432+
auth._request_tokens(
433+
conn=mock.MagicMock(),
434+
authenticator="authenticator",
435+
service_name="service_name",
436+
account="account",
437+
user="user",
438+
)
439+
mock_http_server_init.assert_called_once_with(
440+
uri=uri, redirect_uri=redirect_uri
441+
)

0 commit comments

Comments
 (0)