1111
1212import pytest
1313
14+ # Add cryptography imports for private key handling
15+ from cryptography .hazmat .backends import default_backend
16+ from cryptography .hazmat .primitives import serialization
17+ from cryptography .hazmat .primitives .serialization import (
18+ Encoding ,
19+ NoEncryption ,
20+ PrivateFormat ,
21+ )
22+
1423import snowflake .connector
1524from snowflake .connector .compat import IS_WINDOWS
1625from snowflake .connector .connection import DefaultConverterClass
2837 from snowflake .connector import SnowflakeConnection
2938
3039RUNNING_ON_GH = os .getenv ("GITHUB_ACTIONS" ) == "true"
40+ RUNNING_ON_JENKINS = os .getenv ("JENKINS_HOME" ) not in (None , "false" )
41+ RUNNING_OLD_DRIVER = os .getenv ("TOX_ENV_NAME" ) == "olddriver"
3142TEST_USING_VENDORED_ARROW = os .getenv ("TEST_USING_VENDORED_ARROW" ) == "true"
3243
44+
45+ def _get_private_key_bytes_for_olddriver (private_key_file : str ) -> bytes :
46+ """Load private key file and convert to DER format bytes for olddriver compatibility.
47+
48+ The olddriver expects private keys in DER format as bytes.
49+ This function handles both PEM and DER input formats.
50+ """
51+ with open (private_key_file , "rb" ) as key_file :
52+ key_data = key_file .read ()
53+
54+ # Try to load as PEM first, then DER
55+ try :
56+ # Try PEM format first
57+ private_key = serialization .load_pem_private_key (
58+ key_data ,
59+ password = None ,
60+ backend = default_backend (),
61+ )
62+ except ValueError :
63+ try :
64+ # Try DER format
65+ private_key = serialization .load_der_private_key (
66+ key_data ,
67+ password = None ,
68+ backend = default_backend (),
69+ )
70+ except ValueError as e :
71+ raise ValueError (f"Could not load private key from { private_key_file } : { e } " )
72+
73+ # Convert to DER format bytes as expected by olddriver
74+ return private_key .private_bytes (
75+ encoding = Encoding .DER ,
76+ format = PrivateFormat .PKCS8 ,
77+ encryption_algorithm = NoEncryption (),
78+ )
79+
80+
3381if not isinstance (CONNECTION_PARAMETERS ["host" ], str ):
3482 raise Exception ("default host is not a string in parameters.py" )
3583RUNNING_AGAINST_LOCAL_SNOWFLAKE = CONNECTION_PARAMETERS ["host" ].endswith ("local" )
@@ -72,16 +120,42 @@ def _get_worker_specific_schema():
72120 )
73121
74122
75- DEFAULT_PARAMETERS : dict [str , Any ] = {
76- "account" : "<account_name>" ,
77- "user" : "<user_name>" ,
78- "password" : "<password>" ,
79- "database" : "<database_name>" ,
80- "schema" : "<schema_name>" ,
81- "protocol" : "https" ,
82- "host" : "<host>" ,
83- "port" : "443" ,
84- }
123+ if RUNNING_ON_JENKINS :
124+ DEFAULT_PARAMETERS : dict [str , Any ] = {
125+ "account" : "<account_name>" ,
126+ "user" : "<user_name>" ,
127+ "password" : "<password>" ,
128+ "database" : "<database_name>" ,
129+ "schema" : "<schema_name>" ,
130+ "protocol" : "https" ,
131+ "host" : "<host>" ,
132+ "port" : "443" ,
133+ }
134+ else :
135+ if RUNNING_OLD_DRIVER :
136+ DEFAULT_PARAMETERS : dict [str , Any ] = {
137+ "account" : "<account_name>" ,
138+ "user" : "<user_name>" ,
139+ "database" : "<database_name>" ,
140+ "schema" : "<schema_name>" ,
141+ "protocol" : "https" ,
142+ "host" : "<host>" ,
143+ "port" : "443" ,
144+ "authenticator" : "SNOWFLAKE_JWT" ,
145+ "private_key_file" : "<private_key_file>" ,
146+ }
147+ else :
148+ DEFAULT_PARAMETERS : dict [str , Any ] = {
149+ "account" : "<account_name>" ,
150+ "user" : "<user_name>" ,
151+ "database" : "<database_name>" ,
152+ "schema" : "<schema_name>" ,
153+ "protocol" : "https" ,
154+ "host" : "<host>" ,
155+ "port" : "443" ,
156+ "authenticator" : "<authenticator>" ,
157+ "private_key_file" : "<private_key_file>" ,
158+ }
85159
86160
87161def print_help () -> None :
@@ -91,9 +165,10 @@ def print_help() -> None:
91165CONNECTION_PARAMETERS = {
92166 'account': 'testaccount',
93167 'user': 'user1',
94- 'password': 'test',
95168 'database': 'testdb',
96169 'schema': 'public',
170+ 'authenticator': 'KEY_PAIR_AUTHENTICATOR',
171+ 'private_key_file': '/path/to/private_key.p8',
97172}
98173"""
99174 )
@@ -196,15 +271,48 @@ def init_test_schema(db_parameters) -> Generator[None]:
196271
197272 This is automatically called per test session.
198273 """
199- connection_params = {
200- "user" : db_parameters ["user" ],
201- "password" : db_parameters ["password" ],
202- "host" : db_parameters ["host" ],
203- "port" : db_parameters ["port" ],
204- "database" : db_parameters ["database" ],
205- "account" : db_parameters ["account" ],
206- "protocol" : db_parameters ["protocol" ],
207- }
274+ if RUNNING_ON_JENKINS :
275+ connection_params = {
276+ "user" : db_parameters ["user" ],
277+ "password" : db_parameters ["password" ],
278+ "host" : db_parameters ["host" ],
279+ "port" : db_parameters ["port" ],
280+ "database" : db_parameters ["database" ],
281+ "account" : db_parameters ["account" ],
282+ "protocol" : db_parameters ["protocol" ],
283+ }
284+ else :
285+ connection_params = {
286+ "user" : db_parameters ["user" ],
287+ "host" : db_parameters ["host" ],
288+ "port" : db_parameters ["port" ],
289+ "database" : db_parameters ["database" ],
290+ "account" : db_parameters ["account" ],
291+ "protocol" : db_parameters ["protocol" ],
292+ }
293+
294+ # Handle private key authentication differently for old vs new driver
295+ if RUNNING_OLD_DRIVER :
296+ # Old driver expects private_key as bytes and SNOWFLAKE_JWT authenticator
297+ private_key_file = db_parameters .get ("private_key_file" )
298+ if private_key_file :
299+ private_key_bytes = _get_private_key_bytes_for_olddriver (
300+ private_key_file
301+ )
302+ connection_params .update (
303+ {
304+ "authenticator" : "SNOWFLAKE_JWT" ,
305+ "private_key" : private_key_bytes ,
306+ }
307+ )
308+ else :
309+ # New driver expects private_key_file and KEY_PAIR_AUTHENTICATOR
310+ connection_params .update (
311+ {
312+ "authenticator" : db_parameters ["authenticator" ],
313+ "private_key_file" : db_parameters ["private_key_file" ],
314+ }
315+ )
208316
209317 # Role may be needed when running on preprod, but is not present on Jenkins jobs
210318 optional_role = db_parameters .get ("role" )
@@ -226,6 +334,24 @@ def create_connection(connection_name: str, **kwargs) -> SnowflakeConnection:
226334 """
227335 ret = get_db_parameters (connection_name )
228336 ret .update (kwargs )
337+
338+ # Handle private key authentication differently for old vs new driver (only if not on Jenkins)
339+ if not RUNNING_ON_JENKINS and "private_key_file" in ret :
340+ if RUNNING_OLD_DRIVER :
341+ # Old driver (3.1.0) expects private_key as bytes and SNOWFLAKE_JWT authenticator
342+ private_key_file = ret .get ("private_key_file" )
343+ if (
344+ private_key_file and "private_key" not in ret
345+ ): # Don't override if private_key already set
346+ private_key_bytes = _get_private_key_bytes_for_olddriver (
347+ private_key_file
348+ )
349+ ret ["authenticator" ] = "SNOWFLAKE_JWT"
350+ ret ["private_key" ] = private_key_bytes
351+ ret .pop (
352+ "private_key_file" , None
353+ ) # Remove private_key_file for old driver
354+
229355 connection = snowflake .connector .connect (** ret )
230356 return connection
231357
0 commit comments