Skip to content

Commit 2c61a41

Browse files
committed
Resolving comments
1 parent 992aa6d commit 2c61a41

File tree

2 files changed

+93
-50
lines changed

2 files changed

+93
-50
lines changed

mssql_python/auth.py

Lines changed: 36 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import platform
88
import struct
99
from typing import Tuple, Dict, Optional, Union
10-
from mssql_python.logging_config import get_logger
10+
from mssql_python.logging_config import get_logger, ENABLE_LOGGING
1111
from mssql_python.constants import AuthType
1212

1313
logger = get_logger()
@@ -22,42 +22,38 @@ def get_token_struct(token: str) -> bytes:
2222
return struct.pack(f"<I{len(token_bytes)}s", len(token_bytes), token_bytes)
2323

2424
@staticmethod
25-
def get_default_token() -> bytes:
26-
"""Get token using DefaultAzureCredential"""
27-
from azure.identity import DefaultAzureCredential
28-
29-
try:
30-
# DefaultAzureCredential will automatically use the best available method
31-
# based on the environment (e.g., managed identity, environment variables)
32-
credential = DefaultAzureCredential()
33-
token = credential.get_token("https://database.windows.net/.default").token
34-
return AADAuth.get_token_struct(token)
35-
except Exception as e:
36-
raise RuntimeError(f"Failed to create DefaultAzureCredential: {e}")
37-
38-
@staticmethod
39-
def get_device_code_token() -> bytes:
40-
"""Get token using DeviceCodeCredential"""
41-
from azure.identity import DeviceCodeCredential
42-
43-
try:
44-
credential = DeviceCodeCredential()
45-
token = credential.get_token("https://database.windows.net/.default").token
46-
return AADAuth.get_token_struct(token)
47-
except Exception as e:
48-
raise RuntimeError(f"Failed to create DeviceCodeCredential: {e}")
49-
50-
@staticmethod
51-
def get_interactive_token() -> bytes:
52-
"""Get token using InteractiveBrowserCredential"""
53-
from azure.identity import InteractiveBrowserCredential
25+
def get_token(auth_type: str) -> bytes:
26+
"""Get token using the specified authentication type"""
27+
from azure.identity import (
28+
DefaultAzureCredential,
29+
DeviceCodeCredential,
30+
InteractiveBrowserCredential
31+
)
32+
from azure.core.exceptions import ClientAuthenticationError
33+
34+
# Mapping of auth types to credential classes
35+
credential_map = {
36+
"default": DefaultAzureCredential,
37+
"devicecode": DeviceCodeCredential,
38+
"interactive": InteractiveBrowserCredential,
39+
}
40+
41+
credential_class = credential_map[auth_type]
5442

5543
try:
56-
credential = InteractiveBrowserCredential()
44+
credential = credential_class()
5745
token = credential.get_token("https://database.windows.net/.default").token
5846
return AADAuth.get_token_struct(token)
47+
except ClientAuthenticationError as e:
48+
# Re-raise with more specific context about Azure AD authentication failure
49+
raise RuntimeError(
50+
f"Azure AD authentication failed for {credential_class.__name__}: {e}. "
51+
f"This could be due to invalid credentials, missing environment variables, "
52+
f"user cancellation, network issues, or unsupported configuration."
53+
) from e
5954
except Exception as e:
60-
raise RuntimeError(f"Failed to create InteractiveBrowserCredential: {e}")
55+
# Catch any other unexpected exceptions
56+
raise RuntimeError(f"Failed to create {credential_class.__name__}: {e}") from e
6157

6258
def process_auth_parameters(parameters: list) -> Tuple[list, Optional[str]]:
6359
"""
@@ -120,15 +116,14 @@ def get_auth_token(auth_type: str) -> Optional[bytes]:
120116
if not auth_type:
121117
return None
122118

123-
if auth_type == "default":
124-
return AADAuth.get_default_token()
125-
elif auth_type == "devicecode":
126-
return AADAuth.get_device_code_token()
127-
# If interactive authentication is requested, use InteractiveBrowserCredential
128-
# but only if not on Windows, since in Windows: AADInteractive is supported.
129-
elif auth_type == "interactive" and platform.system().lower() != "windows":
130-
return AADAuth.get_interactive_token()
131-
return None
119+
# Handle platform-specific logic for interactive auth
120+
if auth_type == "interactive" and platform.system().lower() == "windows":
121+
return None # Let Windows handle AADInteractive natively
122+
123+
try:
124+
return AADAuth.get_token(auth_type)
125+
except (ValueError, RuntimeError):
126+
return None
132127

133128
def process_connection_string(connection_string: str) -> Tuple[str, Optional[Dict]]:
134129
"""

tests/test_008_auth.py

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,23 +37,34 @@ class MockInteractiveBrowserCredential:
3737
def get_token(self, scope):
3838
return MockToken()
3939

40+
# Mock ClientAuthenticationError
41+
class MockClientAuthenticationError(Exception):
42+
pass
43+
4044
class MockIdentity:
4145
DefaultAzureCredential = MockDefaultAzureCredential
4246
DeviceCodeCredential = MockDeviceCodeCredential
4347
InteractiveBrowserCredential = MockInteractiveBrowserCredential
4448

49+
class MockCore:
50+
class exceptions:
51+
ClientAuthenticationError = MockClientAuthenticationError
52+
4553
# Create mock azure module if it doesn't exist
4654
if 'azure' not in sys.modules:
4755
sys.modules['azure'] = type('MockAzure', (), {})()
4856

49-
# Add identity module to azure
57+
# Add identity and core modules to azure
5058
sys.modules['azure.identity'] = MockIdentity()
59+
sys.modules['azure.core'] = MockCore()
60+
sys.modules['azure.core.exceptions'] = MockCore.exceptions()
5161

5262
yield
5363

5464
# Cleanup
55-
if 'azure.identity' in sys.modules:
56-
del sys.modules['azure.identity']
65+
for module in ['azure.identity', 'azure.core', 'azure.core.exceptions']:
66+
if module in sys.modules:
67+
del sys.modules[module]
5768

5869
class TestAuthType:
5970
def test_auth_type_constants(self):
@@ -67,18 +78,55 @@ def test_get_token_struct(self):
6778
assert isinstance(token_struct, bytes)
6879
assert len(token_struct) > 4
6980

70-
def test_get_default_token(self):
71-
token_struct = AADAuth.get_default_token()
81+
def test_get_token_default(self):
82+
token_struct = AADAuth.get_token("default")
7283
assert isinstance(token_struct, bytes)
7384

74-
def test_get_device_code_token(self):
75-
token_struct = AADAuth.get_device_code_token()
85+
def test_get_token_device_code(self):
86+
token_struct = AADAuth.get_token("devicecode")
7687
assert isinstance(token_struct, bytes)
7788

78-
def test_get_interactive_token(self):
79-
token_struct = AADAuth.get_interactive_token()
89+
def test_get_token_interactive(self):
90+
token_struct = AADAuth.get_token("interactive")
8091
assert isinstance(token_struct, bytes)
8192

93+
def test_get_token_credential_mapping(self):
94+
# Test that all supported auth types work
95+
supported_types = ["default", "devicecode", "interactive"]
96+
for auth_type in supported_types:
97+
token_struct = AADAuth.get_token(auth_type)
98+
assert isinstance(token_struct, bytes)
99+
assert len(token_struct) > 4
100+
101+
def test_get_token_client_authentication_error(self):
102+
"""Test that ClientAuthenticationError is properly handled"""
103+
from azure.core.exceptions import ClientAuthenticationError
104+
105+
# Create a mock credential that raises ClientAuthenticationError
106+
class MockFailingCredential:
107+
def get_token(self, scope):
108+
raise ClientAuthenticationError("Mock authentication failed")
109+
110+
# Use monkeypatch to mock the credential creation
111+
def mock_get_token_failing(auth_type):
112+
from azure.core.exceptions import ClientAuthenticationError
113+
if auth_type == "default":
114+
try:
115+
credential = MockFailingCredential()
116+
token = credential.get_token("https://database.windows.net/.default").token
117+
return AADAuth.get_token_struct(token)
118+
except ClientAuthenticationError as e:
119+
raise RuntimeError(
120+
f"Azure AD authentication failed for MockFailingCredential: {e}. "
121+
f"This could be due to invalid credentials, missing environment variables, "
122+
f"user cancellation, network issues, or unsupported configuration."
123+
) from e
124+
else:
125+
return AADAuth.get_token(auth_type)
126+
127+
with pytest.raises(RuntimeError, match="Azure AD authentication failed"):
128+
mock_get_token_failing("default")
129+
82130
class TestProcessAuthParameters:
83131
def test_empty_parameters(self):
84132
modified_params, auth_type = process_auth_parameters([])

0 commit comments

Comments
 (0)