Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ EntraID authentication is now fully supported on MacOS and Linux but with certai
| ActiveDirectoryServicePrincipal | ✅ Yes | ✅ Yes | Use client ID and secret or certificate |
| ActiveDirectoryIntegrated | ✅ Yes | ❌ No | Only works on Windows (requires Kerberos/SSPI) |

> **NOTE**: For using Access Token, the connection string *must not* contain `UID`, `PWD`, `Authentication`, or `Trusted_Connection` keywords.

### Enhanced Pythonic Features

The driver offers a suite of Pythonic enhancements that streamline database interactions, making it easier for developers to execute queries, manage connections, and handle data more efficiently.
Expand Down
29 changes: 23 additions & 6 deletions mssql_python/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,17 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef
preparing it for further operations such as connecting to the
database, executing queries, etc.
"""
self.connection_str = self._construct_connection_string(
connection_str, **kwargs
)
self._attrs_before = attrs_before or {}
# Get connection string and potential attrs_before from construction
connection_result = self._construct_connection_string(connection_str, **kwargs)
if isinstance(connection_result, tuple):
self.connection_str, attrs_from_driver = connection_result
# Merge with any existing attrs_before
self._attrs_before = attrs_from_driver or {}
self._attrs_before.update(attrs_from_driver)
else:
self.connection_str = connection_result
self._attrs_before = attrs_before or {}

self._closed = False

# Using WeakSet which automatically removes cursors when they are no longer in use
Expand All @@ -90,10 +97,18 @@ def _construct_connection_string(self, connection_str: str = "", **kwargs) -> st
**kwargs: Additional key/value pairs for the connection string.

Returns:
str: The constructed connection string.
Union[str, Tuple[str, dict]]: Either the constructed connection string,
or a tuple of (connection string, attrs_before dict)
"""
# Add the driver attribute to the connection string
conn_str = add_driver_to_connection_str(connection_str)
result = add_driver_to_connection_str(connection_str)

# Handle both string and tuple return types
if isinstance(result, tuple):
conn_str, attrs_before = result
else:
conn_str = result
attrs_before = None

# Add additional key-value pairs to the connection string
for key, value in kwargs.items():
Expand All @@ -116,6 +131,8 @@ def _construct_connection_string(self, connection_str: str = "", **kwargs) -> st
if ENABLE_LOGGING:
logger.info("Final connection string: %s", conn_str)

if attrs_before is not None:
return conn_str, attrs_before
return conn_str

@property
Expand Down
71 changes: 57 additions & 14 deletions mssql_python/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,21 @@ def add_driver_to_connection_str(connection_str):
connection_str (str): The original connection string.

Returns:
str: The connection string with the DDBC driver added.

Raises:
Exception: If the connection string is invalid.
Union[str, Tuple[str, dict]]: Either the connection string with driver added,
or a tuple of (connection string, attrs_before dict)
"""
driver_name = "Driver={ODBC Driver 18 for SQL Server}"
try:
# Strip any leading or trailing whitespace from the connection string
connection_str = connection_str.strip()
connection_str = add_driver_name_to_app_parameter(connection_str)
result = add_driver_name_to_app_parameter(connection_str)

# Handle both regular string and tuple return types
attrs_before = None
if isinstance(result, tuple):
connection_str, attrs_before = result
else:
connection_str = result

# Split the connection string into individual attributes
connection_attributes = connection_str.split(";")
Expand All @@ -50,15 +55,16 @@ def add_driver_to_connection_str(connection_str):
final_connection_attributes.insert(0, driver_name)
connection_str = ";".join(final_connection_attributes)

if attrs_before is not None:
return connection_str, attrs_before
return connection_str

except Exception as e:
raise Exception(
"Invalid connection string, Please follow the format: "
"Server=server_name;Database=database_name;UID=user_name;PWD=password"
) from e

return connection_str


def check_error(handle_type, handle, ret):
"""
Check for errors and raise an exception if an error is found.
Expand All @@ -80,38 +86,75 @@ def check_error(handle_type, handle, ret):

def add_driver_name_to_app_parameter(connection_string):
"""
Modifies the input connection string by appending the APP name.
Modifies the input connection string by appending the APP name and handling AAD auth.

Args:
connection_string (str): The input connection string.

Returns:
str: The modified connection string.
Union[str, Tuple[str, bytes]]: Either the modified connection string,
or a tuple of (connection string, token bytes) if AAD auth is needed
"""
import sys

# Split the input string into key-value pairs
parameters = connection_string.split(";")

# Initialize variables
app_found = False
modified_parameters = []
has_aad_interactive = False

# Iterate through the key-value pairs
for param in parameters:
param = param.strip()
if not param:
continue

if param.lower().startswith("authentication="):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The auth logic is currently spread across multiple helpers and is resulting in cascading function changes and unstable signatures.
I recommend refactoring all authentication-related code into a dedicated mssql_python/auth.py module. This will centralize and simplify authentication logic and help maintain single responsibility in other modules (like helpers and connection).

# Handle AAD Interactive authentication
key, auth_value = param.split("=", 1)
if auth_value.lower() == "activedirectoryinteractive":
has_aad_interactive = True
# Only keep the auth parameter on Windows
if platform.system().lower() != "windows":
modified_parameters.append(param)
continue

if param.lower().startswith("app="):
# Overwrite the value with 'MSSQL-Python'
app_found = True
key, _ = param.split("=", 1)
modified_parameters.append(f"{key}=MSSQL-Python")
else:
# Keep other parameters as is
modified_parameters.append(param)

# If APP key is not found, append it
if not app_found:
modified_parameters.append("APP=MSSQL-Python")

# Join the parameters back into a connection string
return ";".join(modified_parameters) + ";"
# Handle AAD Interactive auth for non-Windows platforms
if has_aad_interactive and platform.system().lower() != "windows":

# Remove Uid, Pwd, Connection Timeout, Encrypt, TrustServerCertificate
modified_parameters = [
param for param in modified_parameters
if not any(key in param.lower() for key in ["uid=", "pwd=", "connection timeout=", "encrypt=", "trustservercertificate=", "authentication="])
]

try:
from azure.identity import InteractiveBrowserCredential
import struct
except ImportError:
raise ImportError("Please install azure-identity: pip install azure-identity")

credential = InteractiveBrowserCredential()
token_bytes = credential.get_token("https://database.windows.net/.default").token.encode("UTF-16-LE")
token_struct = struct.pack(f"<I{len(token_bytes)}s", len(token_bytes), token_bytes)
conn_str = ";".join(modified_parameters) + ";", {1256: token_struct}
return conn_str

conn_str = ";".join(modified_parameters) + ";"
return conn_str


def detect_linux_distro():
Expand Down
6 changes: 5 additions & 1 deletion mssql_python/row.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,8 @@ def __iter__(self):

def __repr__(self):
"""Return a string representation of the row"""
return f"Row{tuple(self._values)}"
return f"Row{tuple(self._values)}"

def __str__(self):
"""Return a string representation of the row"""
return f"Row({', '.join(map(str, self._values))})"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we wouldn't be needing Row(...) in the string representations, for end users the output should look like a tuple (e.g., (1, 2, 3)).

I suggest updating both the methods as follows:

def __str__(self):
    return str(tuple(self._values))

def __repr__(self):
    return repr(tuple(self._values))

this makes sure that printing or representing a Row instance makes it behave just like a tuple, which is the expected behaviour, also please add tests for this inside test cursor to test fetch functions results.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe these changes should be a separate PR of its own

4 changes: 2 additions & 2 deletions tests/test_003_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_construct_connection_string(db_connection):
assert "TrustServerCertificate=yes;" in conn_str, "Connection string should contain 'TrustServerCertificate=yes;'"
assert "APP=MSSQL-Python" in conn_str, "Connection string should contain 'APP=MSSQL-Python'"
assert "Driver={ODBC Driver 18 for SQL Server}" in conn_str, "Connection string should contain 'Driver={ODBC Driver 18 for SQL Server}'"
assert "Driver={ODBC Driver 18 for SQL Server};;APP=MSSQL-Python;Server=localhost;Uid=me;Pwd=mypwd;Database=mydb;Encrypt=yes;TrustServerCertificate=yes;" == conn_str, "Connection string is incorrect"
assert "Driver={ODBC Driver 18 for SQL Server};APP=MSSQL-Python;Server=localhost;Uid=me;Pwd=mypwd;Database=mydb;Encrypt=yes;TrustServerCertificate=yes;" == conn_str, "Connection string is incorrect"

def test_connection_string_with_attrs_before(db_connection):
# Check if the connection string is constructed correctly with attrs_before
Expand All @@ -70,7 +70,7 @@ def test_connection_string_with_odbc_param(db_connection):
assert "TrustServerCertificate=yes;" in conn_str, "Connection string should contain 'TrustServerCertificate=yes;'"
assert "APP=MSSQL-Python" in conn_str, "Connection string should contain 'APP=MSSQL-Python'"
assert "Driver={ODBC Driver 18 for SQL Server}" in conn_str, "Connection string should contain 'Driver={ODBC Driver 18 for SQL Server}'"
assert "Driver={ODBC Driver 18 for SQL Server};;APP=MSSQL-Python;Server=localhost;Uid=me;Pwd=mypwd;Database=mydb;Encrypt=yes;TrustServerCertificate=yes;" == conn_str, "Connection string is incorrect"
assert "Driver={ODBC Driver 18 for SQL Server};APP=MSSQL-Python;Server=localhost;Uid=me;Pwd=mypwd;Database=mydb;Encrypt=yes;TrustServerCertificate=yes;" == conn_str, "Connection string is incorrect"

def test_autocommit_default(db_connection):
assert db_connection.autocommit is True, "Autocommit should be True by default"
Expand Down