Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions mssql_python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,37 @@
Licensed under the MIT license.
This module initializes the mssql_python package.
"""

import threading
# Exceptions
# https://www.python.org/dev/peps/pep-0249/#exceptions

# GLOBALS
# Read-Only
apilevel = "2.0"
paramstyle = "qmark"
threadsafety = 1

_settings_lock = threading.Lock()

# Create a settings object to hold configuration
class Settings:
def __init__(self):
self.lowercase = False

# Create a global settings instance
_settings = Settings()

# Define the get_settings function for internal use
def get_settings():
"""Return the global settings object"""
with _settings_lock:
_settings.lowercase = lowercase
return _settings

# Expose lowercase as a regular module variable that users can access and set
lowercase = _settings.lowercase

# Import necessary modules
from .exceptions import (
Warning,
Error,
Expand Down Expand Up @@ -52,12 +80,6 @@
SQL_WCHAR = ConstantsDDBC.SQL_WCHAR.value
SQL_WMETADATA = -99

# GLOBALS
# Read-Only
apilevel = "2.0"
paramstyle = "qmark"
threadsafety = 1

from .pooling import PoolingManager
def pooling(max_size=100, idle_timeout=600, enabled=True):
# """
Expand All @@ -76,3 +98,18 @@ def pooling(max_size=100, idle_timeout=600, enabled=True):
PoolingManager.disable()
else:
PoolingManager.enable(max_size, idle_timeout)

import sys
_original_module_setattr = sys.modules[__name__].__setattr__

def _custom_setattr(name, value):
if name == 'lowercase':
with _settings_lock:
_settings.lowercase = bool(value)
# Update the module's lowercase variable
_original_module_setattr(name, _settings.lowercase)
else:
_original_module_setattr(name, value)

# Replace the module's __setattr__ with our custom version
sys.modules[__name__].__setattr__ = _custom_setattr
175 changes: 113 additions & 62 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from mssql_python.helpers import check_error, log
from mssql_python import ddbc_bindings
from mssql_python.exceptions import InterfaceError, NotSupportedError, ProgrammingError
from .row import Row
from mssql_python.row import Row
from mssql_python import get_settings

# Constants for string handling
MAX_INLINE_CHAR = 4000 # NVARCHAR/VARCHAR inline limit; this triggers NVARCHAR(MAX)/VARCHAR(MAX) + DAE
Expand Down Expand Up @@ -543,26 +544,32 @@ def _create_parameter_types_list(self, parameter, param_info, parameters_list, i

return paraminfo

def _initialize_description(self):
"""
Initialize the description attribute using SQLDescribeCol.
"""
col_metadata = []
ret = ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, col_metadata)
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)
def _initialize_description(self, column_metadata=None):
"""Initialize the description attribute from column metadata."""
if not column_metadata:
self.description = None
return

self.description = [
(
col["ColumnName"],
self._map_data_type(col["DataType"]),
None,
col["ColumnSize"],
col["ColumnSize"],
col["DecimalDigits"],
col["Nullable"] == ddbc_sql_const.SQL_NULLABLE.value,
)
for col in col_metadata
]
description = []
for i, col in enumerate(column_metadata):
# Get column name - lowercase it if the lowercase flag is set
column_name = col["ColumnName"]

# Use the current global setting to ensure tests pass correctly
if get_settings().lowercase:
column_name = column_name.lower()

# Add to description tuple (7 elements as per PEP-249)
description.append((
column_name, # name
self._map_data_type(col["DataType"]), # type_code
None, # display_size
col["ColumnSize"], # internal_size
col["ColumnSize"], # precision - should match ColumnSize
col["DecimalDigits"], # scale
col["Nullable"] == ddbc_sql_const.SQL_NULLABLE.value, # null_ok
))
self.description = description

def _map_data_type(self, sql_type):
"""
Expand Down Expand Up @@ -746,6 +753,16 @@ def execute(
use_prepare: Whether to use SQLPrepareW (default) or SQLExecDirectW.
reset_cursor: Whether to reset the cursor before execution.
"""

# Restore original fetch methods if they exist
if hasattr(self, '_original_fetchone'):
self.fetchone = self._original_fetchone
self.fetchmany = self._original_fetchmany
self.fetchall = self._original_fetchall
del self._original_fetchone
del self._original_fetchmany
del self._original_fetchall

self._check_closed() # Check if the cursor is closed
if reset_cursor:
self._reset_cursor()
Expand Down Expand Up @@ -822,7 +839,14 @@ def execute(
self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt)

# Initialize description after execution
self._initialize_description()
# After successful execution, initialize description if there are results
column_metadata = []
try:
ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata)
self._initialize_description(column_metadata)
except Exception as e:
# If describe fails, it's likely there are no results (e.g., for INSERT)
self.description = None

# Reset rownumber for new result set (only for SELECT statements)
if self.description: # If we have column descriptions, it's likely a SELECT
Expand Down Expand Up @@ -975,7 +999,7 @@ def fetchone(self) -> Union[None, Row]:

# Create and return a Row object, passing column name map if available
column_map = getattr(self, '_column_name_map', None)
return Row(row_data, self.description, column_map)
return Row(self, self.description, row_data, column_map)
except Exception as e:
# On error, don't increment rownumber - rethrow the error
raise e
Expand Down Expand Up @@ -1017,7 +1041,7 @@ def fetchmany(self, size: int = None) -> List[Row]:

# Convert raw data to Row objects
column_map = getattr(self, '_column_name_map', None)
return [Row(row_data, self.description, column_map) for row_data in rows_data]
return [Row(self, self.description, row_data, column_map) for row_data in rows_data]
except Exception as e:
# On error, don't increment rownumber - rethrow the error
raise e
Expand Down Expand Up @@ -1049,7 +1073,7 @@ def fetchall(self) -> List[Row]:

# Convert raw data to Row objects
column_map = getattr(self, '_column_name_map', None)
return [Row(row_data, self.description, column_map) for row_data in rows_data]
return [Row(self, self.description, row_data, column_map) for row_data in rows_data]
except Exception as e:
# On error, don't increment rownumber - rethrow the error
raise e
Expand Down Expand Up @@ -1363,30 +1387,20 @@ def tables(self, table=None, catalog=None, schema=None, tableType=None):
Example: "TABLE" or ["TABLE", "VIEW"]

Returns:
list: A list of Row objects containing table information with these columns:
- table_cat: Catalog name
- table_schem: Schema name
- table_name: Table name
- table_type: Table type (e.g., "TABLE", "VIEW")
- remarks: Comments about the table

Notes:
This method only processes the standard five columns as defined in the ODBC
specification. Any additional columns that might be returned by specific ODBC
drivers are not included in the result set.

Cursor: The cursor object itself for method chaining with fetch methods.

Example:
# Get all tables in the database
tables = cursor.tables()
tables = cursor.tables().fetchall()

# Get all tables in schema 'dbo'
tables = cursor.tables(schema='dbo')
tables = cursor.tables(schema='dbo').fetchall()

# Get table named 'Customers'
tables = cursor.tables(table='Customers')
tables = cursor.tables(table='Customers').fetchone()

# Get all views
tables = cursor.tables(tableType='VIEW')
# Get all views with fetchmany
tables = cursor.tables(tableType='VIEW').fetchmany(10)
"""
self._check_closed()

Expand Down Expand Up @@ -1418,7 +1432,13 @@ def tables(self, table=None, catalog=None, schema=None, tableType=None):
try:
ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata)
self._initialize_description(column_metadata)
except Exception:
except InterfaceError as e:
log('error', f"Driver interface error during metadata retrieval: {e}")
except Exception as e:
# Log the exception with appropriate context
log('error', f"Failed to retrieve column metadata: {e}. Using standard ODBC column definitions instead.")

if not self.description:
# If describe fails, create a manual description for the standard columns
column_types = [str, str, str, str, str]
self.description = [
Expand All @@ -1428,23 +1448,54 @@ def tables(self, table=None, catalog=None, schema=None, tableType=None):
("table_type", column_types[3], None, 128, 128, 0, False),
("remarks", column_types[4], None, 254, 254, 0, True)
]

# Define column names in ODBC standard order
column_names = [
"table_cat", "table_schem", "table_name", "table_type", "remarks"
]

# Fetch all rows
rows_data = []
ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data)

# Create a column map for attribute access
column_map = {name: i for i, name in enumerate(column_names)}

# Create Row objects with the column map
result_rows = []
for row_data in rows_data:
row = Row(row_data, self.description, column_map)
result_rows.append(row)

return result_rows

# Store the column mappings for this specific tables() call
column_names = [desc[0] for desc in self.description]

# Create a specialized column map for this result set
columns_map = {}
for i, name in enumerate(column_names):
columns_map[name] = i
columns_map[name.lower()] = i

# Define wrapped fetch methods that preserve existing column mapping
# but add our specialized mapping just for column results
def fetchone_with_columns_mapping():
row = self._original_fetchone()
if row is not None:
# Create a merged map with columns result taking precedence
merged_map = getattr(row, '_column_map', {}).copy()
merged_map.update(columns_map)
row._column_map = merged_map
return row

def fetchmany_with_columns_mapping(size=None):
rows = self._original_fetchmany(size)
for row in rows:
# Create a merged map with columns result taking precedence
merged_map = getattr(row, '_column_map', {}).copy()
merged_map.update(columns_map)
row._column_map = merged_map
return rows

def fetchall_with_columns_mapping():
rows = self._original_fetchall()
for row in rows:
# Create a merged map with columns result taking precedence
merged_map = getattr(row, '_column_map', {}).copy()
merged_map.update(columns_map)
row._column_map = merged_map
return rows

# Save original fetch methods
if not hasattr(self, '_original_fetchone'):
self._original_fetchone = self.fetchone
self._original_fetchmany = self.fetchmany
self._original_fetchall = self.fetchall

# Override fetch methods with our wrapped versions
self.fetchone = fetchone_with_columns_mapping
self.fetchmany = fetchmany_with_columns_mapping
self.fetchall = fetchall_with_columns_mapping

return self
Loading