Skip to content

Commit 1a0290d

Browse files
SNOW-2194055: Separate server and redirect URIs in AuthHttpServer (#2609)
Co-authored-by: Nikita Peshkov <nikita.peshkov@teampicnic.com>
1 parent c14dff9 commit 1a0290d

File tree

7 files changed

+324
-14
lines changed

7 files changed

+324
-14
lines changed

DESCRIPTION.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
1414
- Pin lower versions of dependencies to oldest version without vulnerabilities.
1515
- Added no_proxy parameter for proxy configuration without using environmental variables.
1616
- Added OAUTH_AUTHORIZATION_CODE and OAUTH_CLIENT_CREDENTIALS to list of authenticators that don't require user to be set
17+
- Added `oauth_socket_uri` connection parameter allowing to separate server and redirect URIs for local OAuth server.
1718

1819
- v4.0.0(October 09,2025)
1920
- Added support for checking certificates revocation using revocation lists (CRLs)

src/snowflake/connector/auth/_http_server.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,10 @@ def __init__(
7070
self,
7171
uri: str,
7272
buf_size: int = 16384,
73+
redirect_uri: str | None = None,
7374
) -> None:
7475
parsed_uri = urllib.parse.urlparse(uri)
76+
parsed_redirect = urllib.parse.urlparse(redirect_uri) if redirect_uri else None
7577
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
7678
self.buf_size = buf_size
7779
if os.getenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "False").lower() == "true":
@@ -82,30 +84,34 @@ def __init__(
8284
else:
8385
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
8486

85-
port = parsed_uri.port or 0
87+
if parsed_redirect and self._is_local_uri(parsed_redirect):
88+
server_port = parsed_redirect.port or 0
89+
else:
90+
server_port = parsed_uri.port or 0
91+
8692
for attempt in range(1, self.DEFAULT_MAX_ATTEMPTS + 1):
8793
try:
8894
self._socket.bind(
8995
(
9096
parsed_uri.hostname,
91-
port,
97+
server_port,
9298
)
9399
)
94100
break
95101
except socket.gaierror as ex:
96102
logger.error(
97-
f"Failed to bind authorization callback server to port {port}: {ex}"
103+
f"Failed to bind authorization callback server to port {server_port}: {ex}"
98104
)
99105
raise
100106
except OSError as ex:
101107
if attempt == self.DEFAULT_MAX_ATTEMPTS:
102108
logger.error(
103-
f"Failed to bind authorization callback server to port {port}: {ex}"
109+
f"Failed to bind authorization callback server to port {server_port}: {ex}"
104110
)
105111
raise
106112
logger.warning(
107113
f"Attempt {attempt}/{self.DEFAULT_MAX_ATTEMPTS}. "
108-
f"Failed to bind authorization callback server to port {port}: {ex}"
114+
f"Failed to bind authorization callback server to port {server_port}: {ex}"
109115
)
110116
time.sleep(self.PORT_BIND_TIMEOUT / self.PORT_BIND_MAX_ATTEMPTS)
111117
try:
@@ -114,16 +120,47 @@ def __init__(
114120
logger.error(f"Failed to start listening for auth callback: {ex}")
115121
self.close()
116122
raise
117-
port = self._socket.getsockname()[1]
123+
124+
server_port = self._socket.getsockname()[1]
118125
self._uri = urllib.parse.ParseResult(
119126
scheme=parsed_uri.scheme,
120-
netloc=parsed_uri.hostname + ":" + str(port),
127+
netloc=parsed_uri.hostname + ":" + str(server_port),
121128
path=parsed_uri.path,
122129
params=parsed_uri.params,
123130
query=parsed_uri.query,
124131
fragment=parsed_uri.fragment,
125132
)
126133

134+
if parsed_redirect:
135+
if (
136+
self._is_local_uri(parsed_redirect)
137+
and server_port != parsed_redirect.port
138+
):
139+
logger.debug(
140+
f"Updating redirect port {parsed_redirect.port} to match the server port {server_port}."
141+
)
142+
self._redirect_uri = urllib.parse.ParseResult(
143+
scheme=parsed_redirect.scheme,
144+
netloc=parsed_redirect.hostname + ":" + str(server_port),
145+
path=parsed_redirect.path,
146+
params=parsed_redirect.params,
147+
query=parsed_redirect.query,
148+
fragment=parsed_redirect.fragment,
149+
)
150+
else:
151+
self._redirect_uri = parsed_redirect
152+
else:
153+
# For backwards compatibility
154+
self._redirect_uri = self._uri
155+
156+
@staticmethod
157+
def _is_local_uri(uri):
158+
return uri.hostname in ("localhost", "127.0.0.1")
159+
160+
@property
161+
def redirect_uri(self) -> str | None:
162+
return self._redirect_uri.geturl()
163+
127164
@property
128165
def url(self) -> str:
129166
return self._uri.geturl()

src/snowflake/connector/auth/oauth_code.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def __init__(
6565
external_browser_timeout: int | None = None,
6666
enable_single_use_refresh_tokens: bool = False,
6767
connection: SnowflakeConnection | None = None,
68+
uri: str | None = None,
6869
**kwargs,
6970
) -> None:
7071
authentication_url, redirect_uri = self._validate_oauth_code_uris(
@@ -92,6 +93,7 @@ def __init__(
9293
self._origin: str | None = None
9394
self._authentication_url = authentication_url
9495
self._redirect_uri = redirect_uri
96+
self._uri = uri
9597
self._state = secrets.token_urlsafe(43)
9698
logger.debug("chose oauth state: %s", "".join("*" for _ in self._state))
9799
self._protocol = "http"
@@ -117,7 +119,10 @@ def _request_tokens(
117119
) -> (str | None, str | None):
118120
"""Web Browser based Authentication."""
119121
logger.debug("authenticating with OAuth authorization code flow")
120-
with AuthHttpServer(self._redirect_uri) as callback_server:
122+
with AuthHttpServer(
123+
redirect_uri=self._redirect_uri,
124+
uri=self._uri or self._redirect_uri, # for backward compatibility
125+
) as callback_server:
121126
code = self._do_authorization_request(callback_server, conn)
122127
return self._do_token_request(code, callback_server, conn)
123128

@@ -260,7 +265,7 @@ def _do_authorization_request(
260265
connection: SnowflakeConnection,
261266
) -> str | None:
262267
authorization_request = self._construct_authorization_request(
263-
callback_server.url
268+
callback_server.redirect_uri
264269
)
265270
logger.debug("step 1: going to open authorization URL")
266271
print(
@@ -315,7 +320,7 @@ def _do_token_request(
315320
fields = {
316321
"grant_type": "authorization_code",
317322
"code": code,
318-
"redirect_uri": callback_server.url,
323+
"redirect_uri": callback_server.redirect_uri,
319324
}
320325
if self._enable_single_use_refresh_tokens:
321326
fields["enable_single_use_refresh_tokens"] = "true"

src/snowflake/connector/connection.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,11 @@ def _get_private_bytes_from_file(
384384
# SNOW-1825621: OAUTH implementation
385385
),
386386
"oauth_redirect_uri": ("http://127.0.0.1", str),
387+
"oauth_socket_uri": (
388+
"http://127.0.0.1",
389+
str,
390+
# SNOW-2194055: Separate server and redirect URIs in AuthHttpServer
391+
),
387392
"oauth_scope": (
388393
"",
389394
str,
@@ -1458,6 +1463,7 @@ def __open_connection(self):
14581463
host=self.host, port=self.port
14591464
),
14601465
redirect_uri=self._oauth_redirect_uri,
1466+
uri=self._oauth_socket_uri,
14611467
scope=self._oauth_scope,
14621468
pkce_enabled=not self._oauth_disable_pkce,
14631469
token_cache=(

src/snowflake/connector/constants.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,6 @@ class IterUnit(Enum):
442442
ENV_VAR_PARTNER = "SF_PARTNER"
443443
ENV_VAR_TEST_MODE = "SNOWFLAKE_TEST_MODE"
444444

445-
446445
_DOMAIN_NAME_MAP = {_DEFAULT_HOSTNAME_TLD: "GLOBAL", _CHINA_HOSTNAME_TLD: "CHINA"}
447446

448447
_CONNECTIVITY_ERR_MSG = (

test/unit/test_auth_callback_server.py

Lines changed: 152 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ def test_auth_callback_success(monkeypatch, dontwait, timeout, reuse_port) -> No
2424
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port)
2525
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait)
2626
test_response: requests.Response | None = None
27-
with AuthHttpServer("http://127.0.0.1/test_request") as callback_server:
27+
with AuthHttpServer(
28+
"http://127.0.0.1/test_request",
29+
) as callback_server:
2830

2931
def request_callback():
3032
nonlocal test_response
@@ -57,7 +59,155 @@ def request_callback():
5759
def test_auth_callback_timeout(monkeypatch, dontwait, timeout, reuse_port) -> None:
5860
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port)
5961
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait)
60-
with AuthHttpServer("http://127.0.0.1/test_request") as callback_server:
62+
with AuthHttpServer(
63+
"http://127.0.0.1/test_request",
64+
) as callback_server:
6165
block, client_socket = callback_server.receive_block(timeout=timeout)
6266
assert block is None
6367
assert client_socket is None
68+
69+
70+
@pytest.mark.parametrize(
71+
"socket_host",
72+
[
73+
"127.0.0.1",
74+
"localhost",
75+
],
76+
)
77+
@pytest.mark.parametrize(
78+
"socket_port",
79+
[
80+
"",
81+
":0",
82+
":12345",
83+
],
84+
)
85+
@pytest.mark.parametrize(
86+
"redirect_host",
87+
[
88+
"127.0.0.1",
89+
"localhost",
90+
],
91+
)
92+
@pytest.mark.parametrize(
93+
"redirect_port",
94+
[
95+
"",
96+
":0",
97+
":12345",
98+
],
99+
)
100+
@pytest.mark.parametrize(
101+
"dontwait",
102+
["false", "true"],
103+
)
104+
@pytest.mark.parametrize("reuse_port", ["true", "false"])
105+
def test_auth_callback_server_updates_localhost_redirect_uri_port_to_match_socket_port(
106+
monkeypatch,
107+
socket_host,
108+
socket_port,
109+
redirect_host,
110+
redirect_port,
111+
dontwait,
112+
reuse_port,
113+
) -> None:
114+
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port)
115+
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait)
116+
with AuthHttpServer(
117+
uri=f"http://{socket_host}{socket_port}/test_request",
118+
redirect_uri=f"http://{redirect_host}{redirect_port}/test_request",
119+
) as callback_server:
120+
assert callback_server._redirect_uri.port == callback_server.port
121+
122+
123+
@pytest.mark.parametrize(
124+
"socket_host",
125+
[
126+
"127.0.0.1",
127+
"localhost",
128+
],
129+
)
130+
@pytest.mark.parametrize(
131+
"socket_port",
132+
[
133+
"",
134+
":0",
135+
":12345",
136+
],
137+
)
138+
@pytest.mark.parametrize(
139+
"redirect_host",
140+
[
141+
"127.0.0.1",
142+
"localhost",
143+
],
144+
)
145+
@pytest.mark.parametrize(
146+
"redirect_port",
147+
[
148+
54321,
149+
54320,
150+
],
151+
)
152+
@pytest.mark.parametrize(
153+
"dontwait",
154+
["false", "true"],
155+
)
156+
@pytest.mark.parametrize("reuse_port", ["true", "false"])
157+
def test_auth_callback_server_uses_redirect_uri_port_when_specified(
158+
monkeypatch,
159+
socket_host,
160+
socket_port,
161+
redirect_host,
162+
redirect_port,
163+
dontwait,
164+
reuse_port,
165+
) -> None:
166+
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port)
167+
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait)
168+
with AuthHttpServer(
169+
uri=f"http://{socket_host}{socket_port}/test_request",
170+
redirect_uri=f"http://{redirect_host}:{redirect_port}/test_request",
171+
) as callback_server:
172+
assert callback_server.port == redirect_port
173+
assert callback_server._redirect_uri.port == redirect_port
174+
175+
176+
@pytest.mark.parametrize(
177+
"socket_host",
178+
[
179+
"127.0.0.1",
180+
"localhost",
181+
],
182+
)
183+
@pytest.mark.parametrize(
184+
"socket_port",
185+
[
186+
"",
187+
":0",
188+
":12345",
189+
],
190+
)
191+
@pytest.mark.parametrize(
192+
"redirect_port",
193+
[
194+
"",
195+
":0",
196+
":12345",
197+
],
198+
)
199+
@pytest.mark.parametrize(
200+
"dontwait",
201+
["false", "true"],
202+
)
203+
@pytest.mark.parametrize("reuse_port", ["true", "false"])
204+
def test_auth_callback_server_does_not_updates_nonlocalhost_redirect_uri_port_to_match_socket_port(
205+
monkeypatch, socket_host, socket_port, redirect_port, dontwait, reuse_port
206+
) -> None:
207+
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port)
208+
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait)
209+
redirect_uri = f"http://not_localhost{redirect_port}/test_request"
210+
with AuthHttpServer(
211+
uri=f"http://{socket_host}{socket_port}/test_request", redirect_uri=redirect_uri
212+
) as callback_server:
213+
assert callback_server.redirect_uri == redirect_uri

0 commit comments

Comments
 (0)