Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 44 additions & 14 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from mssql_python.exceptions import InterfaceError, NotSupportedError, ProgrammingError
from .row import Row

# Constants for string handling
MAX_INLINE_CHAR = 4000 # NVARCHAR/VARCHAR inline limit; this triggers NVARCHAR(MAX)/VARCHAR(MAX) + DAE

class Cursor:
"""
Expand Down Expand Up @@ -233,10 +235,11 @@ def _map_sql_type(self, param, parameters_list, i):
ddbc_sql_const.SQL_C_DEFAULT.value,
1,
0,
False,
)

if isinstance(param, bool):
return ddbc_sql_const.SQL_BIT.value, ddbc_sql_const.SQL_C_BIT.value, 1, 0
return ddbc_sql_const.SQL_BIT.value, ddbc_sql_const.SQL_C_BIT.value, 1, 0, False

if isinstance(param, int):
if 0 <= param <= 255:
Expand All @@ -245,26 +248,30 @@ def _map_sql_type(self, param, parameters_list, i):
ddbc_sql_const.SQL_C_TINYINT.value,
3,
0,
False,
)
if -32768 <= param <= 32767:
return (
ddbc_sql_const.SQL_SMALLINT.value,
ddbc_sql_const.SQL_C_SHORT.value,
5,
0,
False,
)
if -2147483648 <= param <= 2147483647:
return (
ddbc_sql_const.SQL_INTEGER.value,
ddbc_sql_const.SQL_C_LONG.value,
10,
0,
False,
)
return (
ddbc_sql_const.SQL_BIGINT.value,
ddbc_sql_const.SQL_C_SBIGINT.value,
19,
0,
False,
)

if isinstance(param, float):
Expand All @@ -273,6 +280,7 @@ def _map_sql_type(self, param, parameters_list, i):
ddbc_sql_const.SQL_C_DOUBLE.value,
15,
0,
False,
)

if isinstance(param, decimal.Decimal):
Expand All @@ -284,6 +292,7 @@ def _map_sql_type(self, param, parameters_list, i):
ddbc_sql_const.SQL_C_NUMERIC.value,
parameters_list[i].precision,
parameters_list[i].scale,
False,
)

if isinstance(param, str):
Expand All @@ -297,6 +306,7 @@ def _map_sql_type(self, param, parameters_list, i):
ddbc_sql_const.SQL_C_WCHAR.value,
len(param),
0,
False,
)

# Attempt to parse as date, datetime, datetime2, timestamp, smalldatetime or time
Expand All @@ -309,6 +319,7 @@ def _map_sql_type(self, param, parameters_list, i):
ddbc_sql_const.SQL_C_TYPE_DATE.value,
10,
0,
False,
)
if self._parse_datetime(param):
parameters_list[i] = self._parse_datetime(param)
Expand All @@ -317,6 +328,7 @@ def _map_sql_type(self, param, parameters_list, i):
ddbc_sql_const.SQL_C_TYPE_TIMESTAMP.value,
26,
6,
False,
)
if self._parse_time(param):
parameters_list[i] = self._parse_time(param)
Expand All @@ -325,25 +337,26 @@ def _map_sql_type(self, param, parameters_list, i):
ddbc_sql_const.SQL_C_TYPE_TIME.value,
8,
0,
False,
)

# String mapping logic here
is_unicode = self._is_unicode_string(param)
# TODO: revisit
if len(param) > 4000: # Long strings
if len(param) > MAX_INLINE_CHAR: # Long strings
if is_unicode:
utf16_len = len(param.encode("utf-16-le")) // 2
return (
ddbc_sql_const.SQL_WLONGVARCHAR.value,
ddbc_sql_const.SQL_C_WCHAR.value,
utf16_len,
len(param),
0,
True,
)
return (
ddbc_sql_const.SQL_LONGVARCHAR.value,
ddbc_sql_const.SQL_C_CHAR.value,
len(param),
0,
True,
)
if is_unicode: # Short Unicode strings
utf16_len = len(param.encode("utf-16-le")) // 2
Expand All @@ -352,12 +365,14 @@ def _map_sql_type(self, param, parameters_list, i):
ddbc_sql_const.SQL_C_WCHAR.value,
utf16_len,
0,
False,
)
return (
ddbc_sql_const.SQL_VARCHAR.value,
ddbc_sql_const.SQL_C_CHAR.value,
len(param),
0,
False,
)

if isinstance(param, bytes):
Expand All @@ -367,12 +382,14 @@ def _map_sql_type(self, param, parameters_list, i):
ddbc_sql_const.SQL_C_BINARY.value,
len(param),
0,
False,
)
return (
ddbc_sql_const.SQL_BINARY.value,
ddbc_sql_const.SQL_C_BINARY.value,
len(param),
0,
False,
)

if isinstance(param, bytearray):
Expand All @@ -382,12 +399,14 @@ def _map_sql_type(self, param, parameters_list, i):
ddbc_sql_const.SQL_C_BINARY.value,
len(param),
0,
True,
)
return (
ddbc_sql_const.SQL_BINARY.value,
ddbc_sql_const.SQL_C_BINARY.value,
len(param),
0,
False,
)

if isinstance(param, datetime.datetime):
Expand All @@ -396,6 +415,7 @@ def _map_sql_type(self, param, parameters_list, i):
ddbc_sql_const.SQL_C_TYPE_TIMESTAMP.value,
26,
6,
False,
)

if isinstance(param, datetime.date):
Expand All @@ -404,6 +424,7 @@ def _map_sql_type(self, param, parameters_list, i):
ddbc_sql_const.SQL_C_TYPE_DATE.value,
10,
0,
False,
)

if isinstance(param, datetime.time):
Expand All @@ -412,14 +433,11 @@ def _map_sql_type(self, param, parameters_list, i):
ddbc_sql_const.SQL_C_TYPE_TIME.value,
8,
0,
False,
)

return (
ddbc_sql_const.SQL_VARCHAR.value,
ddbc_sql_const.SQL_C_CHAR.value,
len(str(param)),
0,
)
# For safety: unknown/unhandled Python types should not silently go to SQL
raise TypeError("Unsupported parameter type: The driver cannot safely convert it to a SQL type.")

def _initialize_cursor(self) -> None:
"""
Expand Down Expand Up @@ -495,14 +513,19 @@ def _create_parameter_types_list(self, parameter, param_info, parameters_list, i
paraminfo.
"""
paraminfo = param_info()
sql_type, c_type, column_size, decimal_digits = self._map_sql_type(
sql_type, c_type, column_size, decimal_digits, is_dae = self._map_sql_type(
parameter, parameters_list, i
)
paraminfo.paramCType = c_type
paraminfo.paramSQLType = sql_type
paraminfo.inputOutputType = ddbc_sql_const.SQL_PARAM_INPUT.value
paraminfo.columnSize = column_size
paraminfo.decimalDigits = decimal_digits
paraminfo.isDAE = is_dae

if is_dae:
paraminfo.dataPtr = parameter # Will be converted to py::object* in C++

return paraminfo

def _initialize_description(self):
Expand Down Expand Up @@ -762,9 +785,16 @@ def execute(
self.is_stmt_prepared,
use_prepare,
)

# Check return code
try:

# Check for errors but don't raise exceptions for info/warning messages
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)
except Exception as e:
log('warning', "Execute failed, resetting cursor: %s", e)
self._reset_cursor()
raise


# Capture any diagnostic messages (SQL_SUCCESS_WITH_INFO, etc.)
if self.hstmt:
Expand Down
Loading