Skip to content

Commit 621391a

Browse files
committed
Forward all connection string params to py-core in bulkcopy
1 parent d532b8e commit 621391a

File tree

2 files changed

+83
-39
lines changed

2 files changed

+83
-39
lines changed

mssql_python/cursor.py

Lines changed: 7 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import warnings
1818
from typing import List, Union, Any, Optional, Tuple, Sequence, TYPE_CHECKING, Iterable
1919
from mssql_python.constants import ConstantsDDBC as ddbc_sql_const, SQLTypes
20-
from mssql_python.helpers import check_error
20+
from mssql_python.helpers import check_error, connstr_to_pycore_params
2121
from mssql_python.logging import logger
2222
from mssql_python import ddbc_bindings
2323
from mssql_python.exceptions import (
@@ -2451,6 +2451,7 @@ def nextset(self) -> Union[bool, None]:
24512451
)
24522452
return True
24532453

2454+
# ── Mapping from ODBC connection-string keywords (lowercase, as _parse returns)
24542455
def _bulkcopy(
24552456
self,
24562457
table_name: str,
@@ -2585,38 +2586,10 @@ def _bulkcopy(
25852586
"Specify the target database explicitly to avoid accidentally writing to system databases."
25862587
)
25872588

2588-
# Build connection context for bulk copy library
2589-
# Note: Password is extracted separately to avoid storing it in the main context
2590-
# dict that could be accidentally logged or exposed in error messages.
2591-
trust_cert = params.get("trustservercertificate", "yes").lower() in ("yes", "true")
2592-
2593-
# Parse encryption setting from connection string
2594-
encrypt_param = params.get("encrypt")
2595-
if encrypt_param is not None:
2596-
encrypt_value = encrypt_param.strip().lower()
2597-
if encrypt_value in ("yes", "true", "mandatory", "required"):
2598-
encryption = "Required"
2599-
elif encrypt_value in ("no", "false", "optional"):
2600-
encryption = "Optional"
2601-
else:
2602-
# Pass through unrecognized values (e.g., "Strict") to the underlying driver
2603-
encryption = encrypt_param
2604-
else:
2605-
encryption = "Optional"
2606-
2607-
context = {
2608-
"server": params.get("server"),
2609-
"database": params.get("database"),
2610-
"trust_server_certificate": trust_cert,
2611-
"encryption": encryption,
2612-
}
2613-
2614-
# Build pycore_context with appropriate authentication.
2615-
# For Azure AD: acquire a FRESH token right now instead of reusing
2616-
# the one from connect() time — avoids expired-token errors when
2617-
# bulkcopy() is called long after the original connection.
2618-
pycore_context = dict(context)
2589+
# Translate parsed connection string into the dict py-core expects.
2590+
pycore_context = connstr_to_pycore_params(params)
26192591

2592+
# Token acquisition — only thing cursor must handle (needs azure-identity SDK)
26202593
if self.connection._auth_type:
26212594
# Fresh token acquisition for mssql-py-core connection
26222595
from mssql_python.auth import AADAuth
@@ -2633,10 +2606,6 @@ def _bulkcopy(
26332606
"Bulk copy: acquired fresh Azure AD token for auth_type=%s",
26342607
self.connection._auth_type,
26352608
)
2636-
else:
2637-
# SQL Server authentication — use uid/password from connection string
2638-
pycore_context["user_name"] = params.get("uid", "")
2639-
pycore_context["password"] = params.get("pwd", "")
26402609

26412610
pycore_connection = None
26422611
pycore_cursor = None
@@ -2675,9 +2644,8 @@ def _bulkcopy(
26752644
finally:
26762645
# Clear sensitive data to minimize memory exposure
26772646
if pycore_context:
2678-
pycore_context.pop("password", None)
2679-
pycore_context.pop("user_name", None)
2680-
pycore_context.pop("access_token", None)
2647+
for key in ("password", "user_name", "access_token"):
2648+
pycore_context.pop(key, None)
26812649
# Clean up bulk copy resources
26822650
for resource in (pycore_cursor, pycore_connection):
26832651
if resource and hasattr(resource, "close"):

mssql_python/helpers.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,82 @@ def _sanitize_for_logging(input_val: Any, max_length: int = max_log_length) -> s
250250
return True, None, sanitized_attr, sanitized_val
251251

252252

253+
def connstr_to_pycore_params(params: dict) -> dict:
254+
"""Translate parsed connection-string params into the dict that py-core's
255+
``connection.rs`` expects.
256+
257+
*params* uses the lowercase keys that ``_ConnectionStringParser._parse``
258+
returns. This function maps them to py-core's snake_case keys and
259+
converts str values to bool/int where the Rust side expects native types.
260+
Keys not recognised by py-core are silently dropped.
261+
"""
262+
# connstr key (lowercase) → py-core dict key
263+
key_map = {
264+
# auth / credentials
265+
"uid": "user_name",
266+
"pwd": "password",
267+
"trusted_connection": "trusted_connection",
268+
"authentication": "authentication",
269+
# server (accept parser synonyms)
270+
"server": "server",
271+
"addr": "server",
272+
"address": "server",
273+
# database
274+
"database": "database",
275+
"app": "application_name",
276+
"applicationintent": "application_intent",
277+
"workstationid": "workstation_id",
278+
"language": "language",
279+
# encryption / TLS (include snake_case alias the parser may emit)
280+
"encrypt": "encryption",
281+
"trustservercertificate": "trust_server_certificate",
282+
"trust_server_certificate": "trust_server_certificate",
283+
"hostnameincertificate": "host_name_in_certificate",
284+
"servercertificate": "server_certificate",
285+
# Kerberos
286+
"serverspn": "server_spn",
287+
# network
288+
"multisubnetfailover": "multi_subnet_failover",
289+
"ipaddresspreference": "ip_address_preference",
290+
"keepalive": "keep_alive",
291+
"keepaliveinterval": "keep_alive_interval",
292+
# sizing / limits ("packet size" with space is a common pyodbc-ism)
293+
"packetsize": "packet_size",
294+
"packet size": "packet_size",
295+
"connect_timeout": "connect_timeout",
296+
"connectretrycount": "connect_retry_count",
297+
"connectretryinterval": "connect_retry_interval",
298+
# MARS
299+
"mars_connection": "mars_enabled",
300+
}
301+
bool_keys = {"trust_server_certificate", "multi_subnet_failover", "mars_enabled"}
302+
int_keys = {
303+
"connect_timeout",
304+
"packet_size",
305+
"connect_retry_count",
306+
"connect_retry_interval",
307+
"keep_alive",
308+
"keep_alive_interval",
309+
}
310+
311+
ctx: dict = {}
312+
for src, dst in key_map.items():
313+
val = params.get(src)
314+
if val is None:
315+
continue
316+
if dst in bool_keys:
317+
ctx[dst] = val.lower() in ("yes", "true", "1") if isinstance(val, str) else bool(val)
318+
elif dst in int_keys:
319+
try:
320+
ctx[dst] = int(val)
321+
except (ValueError, TypeError):
322+
pass # let py-core use its default
323+
else:
324+
ctx[dst] = val
325+
326+
return ctx
327+
328+
253329
# Settings functionality moved here to avoid circular imports
254330

255331
# Initialize the locale setting only once at module import time

0 commit comments

Comments
 (0)