@@ -69,11 +69,11 @@ class AuthHttpServer:
6969 def __init__ (
7070 self ,
7171 uri : str ,
72- redirect_uri : str ,
7372 buf_size : int = 16384 ,
73+ redirect_uri : str | None = None ,
7474 ) -> None :
7575 parsed_uri = urllib .parse .urlparse (uri )
76- parsed_redirect = urllib .parse .urlparse (redirect_uri )
76+ parsed_redirect = urllib .parse .urlparse (redirect_uri ) if redirect_uri else None
7777 self ._socket = socket .socket (socket .AF_INET , socket .SOCK_STREAM )
7878 self .buf_size = buf_size
7979 if os .getenv ("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT" , "False" ).lower () == "true" :
@@ -84,10 +84,11 @@ def __init__(
8484 else :
8585 self ._socket .setsockopt (socket .SOL_SOCKET , socket .SO_REUSEPORT , 1 )
8686
87- if parsed_redirect . hostname in ( "localhost" , "127.0.0.1" ):
87+ if parsed_redirect and self . _is_local_uri ( parsed_redirect ):
8888 port = parsed_redirect .port or 0
8989 else :
90- port = parsed_uri .port or 0
90+ port = parsed_uri .port if parsed_uri and parsed_uri .port else 0
91+
9192 for attempt in range (1 , self .DEFAULT_MAX_ATTEMPTS + 1 ):
9293 try :
9394 self ._socket .bind (
@@ -128,27 +129,30 @@ def __init__(
128129 query = parsed_uri .query ,
129130 fragment = parsed_uri .fragment ,
130131 )
131- if (
132- parsed_redirect .hostname in ("localhost" , "127.0.0.1" )
133- and port != parsed_redirect .port
134- ):
135- logger .debug (
136- f"Updating redirect port { parsed_redirect .port } to match the server port { port } ."
137- )
138- self ._redirect_uri = urllib .parse .ParseResult (
139- scheme = parsed_redirect .scheme ,
140- netloc = parsed_redirect .hostname + ":" + str (port ),
141- path = parsed_redirect .path ,
142- params = parsed_redirect .params ,
143- query = parsed_redirect .query ,
144- fragment = parsed_redirect .fragment ,
145- )
146- else :
147- self ._redirect_uri = parsed_redirect
132+ if parsed_redirect :
133+ if self ._is_local_uri (parsed_redirect ) and port != parsed_redirect .port :
134+ logger .debug (
135+ f"Updating redirect port { parsed_redirect .port } to match the server port { port } ."
136+ )
137+ self ._redirect_uri = urllib .parse .ParseResult (
138+ scheme = parsed_redirect .scheme ,
139+ netloc = parsed_redirect .hostname + ":" + str (port ),
140+ path = parsed_redirect .path ,
141+ params = parsed_redirect .params ,
142+ query = parsed_redirect .query ,
143+ fragment = parsed_redirect .fragment ,
144+ )
145+ else :
146+ self ._redirect_uri = parsed_redirect
147+
148+ def _is_local_uri (self , parsed_redirect ):
149+ return parsed_redirect .hostname in ("localhost" , "127.0.0.1" )
148150
149151 @property
150- def redirect_uri (self ) -> str :
151- return self ._redirect_uri .geturl ()
152+ def redirect_uri (self ) -> str | None :
153+ if self ._redirect_uri :
154+ return self ._redirect_uri .geturl ()
155+ return self .url
152156
153157 @property
154158 def url (self ) -> str :
0 commit comments