Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 84 additions & 19 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

logger = get_logger()


class Cursor:
"""
Represents a database cursor, which is used to manage the context of a fetch operation.
Expand Down Expand Up @@ -631,37 +630,103 @@ def execute(
# Initialize description after execution
self._initialize_description()

@staticmethod
def _select_best_sample_value(column):
"""
Selects the most representative non-null value from a column for type inference.

This is used during executemany() to infer SQL/C types based on actual data,
preferring a non-null value that is not the first row to avoid bias from placeholder defaults.

Args:
column: List of values in the column.
"""
non_nulls = [v for v in column if v is not None]
if not non_nulls:
return None
if all(isinstance(v, int) for v in non_nulls):
# Pick the value with the widest range (min/max)
return max(non_nulls, key=lambda v: abs(v))
if all(isinstance(v, float) for v in non_nulls):
return 0.0
if all(isinstance(v, decimal.Decimal) for v in non_nulls):
return max(non_nulls, key=lambda d: len(d.as_tuple().digits))
if all(isinstance(v, str) for v in non_nulls):
return max(non_nulls, key=lambda s: len(str(s)))
if all(isinstance(v, datetime.datetime) for v in non_nulls):
return datetime.datetime.now()
if all(isinstance(v, datetime.date) for v in non_nulls):
return datetime.date.today()
return non_nulls[0] # fallback

def _transpose_rowwise_to_columnwise(self, seq_of_parameters: list) -> list:
"""
Convert list of rows (row-wise) into list of columns (column-wise),
for array binding via ODBC.
Args:
seq_of_parameters: Sequence of sequences or mappings of parameters.
"""
if not seq_of_parameters:
return []

num_params = len(seq_of_parameters[0])
columnwise = [[] for _ in range(num_params)]
for row in seq_of_parameters:
if len(row) != num_params:
raise ValueError("Inconsistent parameter row size in executemany()")
for i, val in enumerate(row):
columnwise[i].append(val)
return columnwise

def executemany(self, operation: str, seq_of_parameters: list) -> None:
"""
Prepare a database operation and execute it against all parameter sequences.

This version uses column-wise parameter binding and a single batched SQLExecute().
Args:
operation: SQL query or command.
seq_of_parameters: Sequence of sequences or mappings of parameters.

Raises:
Error: If the operation fails.
"""
self._check_closed() # Check if the cursor is closed

self._check_closed()
self._reset_cursor()

first_execution = True
total_rowcount = 0
for parameters in seq_of_parameters:
parameters = list(parameters)
if ENABLE_LOGGING:
logger.info("Executing query with parameters: %s", parameters)
prepare_stmt = first_execution
first_execution = False
self.execute(
operation, parameters, use_prepare=prepare_stmt, reset_cursor=False
if not seq_of_parameters:
self.rowcount = 0
return

param_info = ddbc_bindings.ParamInfo
param_count = len(seq_of_parameters[0])
parameters_type = []

for col_index in range(param_count):
column = [row[col_index] for row in seq_of_parameters]
sample_value = self._select_best_sample_value(column)
dummy_row = list(seq_of_parameters[0])
parameters_type.append(
self._create_parameter_types_list(sample_value, param_info, dummy_row, col_index)
)
if self.rowcount != -1:
total_rowcount += self.rowcount
else:
total_rowcount = -1
self.rowcount = total_rowcount

columnwise_params = self._transpose_rowwise_to_columnwise(seq_of_parameters)
if ENABLE_LOGGING:
logger.info("Executing batch query with %d parameter sets:\n%s",
len(seq_of_parameters),"\n".join(f" {i+1}: {tuple(p) if isinstance(p, (list, tuple)) else p}" for i, p in enumerate(seq_of_parameters))
)

# Execute batched statement
ret = ddbc_bindings.SQLExecuteMany(
self.hstmt,
operation,
columnwise_params,
parameters_type,
len(seq_of_parameters)
)
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)

self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt)
self.last_executed_stmt = operation
self._initialize_description()

def fetchone(self) -> Union[None, Row]:
"""
Expand Down
Loading