Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,10 @@ def _map_sql_type(self, param, parameters_list, i):

# String mapping logic here
is_unicode = self._is_unicode_string(param)
if len(param) > MAX_INLINE_CHAR: # Long strings

# Computes UTF-16 code units (handles surrogate pairs)
utf16_len = sum(2 if ord(c) > 0xFFFF else 1 for c in param)
if utf16_len > MAX_INLINE_CHAR: # Long strings -> DAE
if is_unicode:
return (
ddbc_sql_const.SQL_WLONGVARCHAR.value,
Expand All @@ -358,8 +361,9 @@ def _map_sql_type(self, param, parameters_list, i):
0,
True,
)
if is_unicode: # Short Unicode strings
utf16_len = len(param.encode("utf-16-le")) // 2

# Short strings
if is_unicode:
return (
ddbc_sql_const.SQL_WVARCHAR.value,
ddbc_sql_const.SQL_C_WCHAR.value,
Expand All @@ -374,7 +378,7 @@ def _map_sql_type(self, param, parameters_list, i):
0,
False,
)

if isinstance(param, bytes):
if len(param) > 8000: # Assuming VARBINARY(MAX) for long byte arrays
return (
Expand Down
71 changes: 54 additions & 17 deletions mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,27 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params,

// TODO: Add more data types like money, guid, interval, TVPs etc.
switch (paramInfo.paramCType) {
case SQL_C_CHAR:
case SQL_C_CHAR: {
if (!py::isinstance<py::str>(param) && !py::isinstance<py::bytearray>(param) &&
!py::isinstance<py::bytes>(param)) {
ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex));
}
if (paramInfo.isDAE) {
LOG("Parameter[{}] is marked for DAE streaming", paramIndex);
dataPtr = const_cast<void*>(reinterpret_cast<const void*>(&paramInfos[paramIndex]));
strLenOrIndPtr = AllocateParamBuffer<SQLLEN>(paramBuffers);
*strLenOrIndPtr = SQL_LEN_DATA_AT_EXEC(0);
bufferLength = 0;
} else {
std::string* strParam =
AllocateParamBuffer<std::string>(paramBuffers, param.cast<std::string>());
dataPtr = const_cast<void*>(static_cast<const void*>(strParam->c_str()));
bufferLength = strParam->size() + 1;
strLenOrIndPtr = AllocateParamBuffer<SQLLEN>(paramBuffers);
*strLenOrIndPtr = SQL_NTS;
}
break;
}
case SQL_C_BINARY: {
if (!py::isinstance<py::str>(param) && !py::isinstance<py::bytearray>(param) &&
!py::isinstance<py::bytes>(param)) {
Expand Down Expand Up @@ -1203,23 +1223,40 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle,
continue;
}
if (py::isinstance<py::str>(pyObj)) {
std::wstring wstr = pyObj.cast<std::wstring>();
#if defined(__APPLE__) || defined(__linux__)
auto utf16Buf = WStringToSQLWCHAR(wstr);
const char* dataPtr = reinterpret_cast<const char*>(utf16Buf.data());
size_t totalBytes = (utf16Buf.size() - 1) * sizeof(SQLWCHAR);
#else
const char* dataPtr = reinterpret_cast<const char*>(wstr.data());
size_t totalBytes = wstr.size() * sizeof(wchar_t);
#endif
const size_t chunkSize = DAE_CHUNK_SIZE;
for (size_t offset = 0; offset < totalBytes; offset += chunkSize) {
size_t len = std::min(chunkSize, totalBytes - offset);
rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), static_cast<SQLLEN>(len));
if (!SQL_SUCCEEDED(rc)) {
LOG("SQLPutData failed at offset {} of {}", offset, totalBytes);
return rc;
if (matchedInfo->paramCType == SQL_C_WCHAR) {
std::wstring wstr = pyObj.cast<std::wstring>();
std::vector<SQLWCHAR> sqlwStr = WStringToSQLWCHAR(wstr);
size_t totalChars = sqlwStr.size() - 1;
const SQLWCHAR* dataPtr = sqlwStr.data();
size_t offset = 0;
size_t chunkChars = DAE_CHUNK_SIZE / sizeof(SQLWCHAR);
while (offset < totalChars) {
size_t len = std::min(chunkChars, totalChars - offset);
rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), static_cast<SQLLEN>(len));
if (!SQL_SUCCEEDED(rc)) {
LOG("SQLPutData failed at offset {} of {}", offset, totalChars);
return rc;
}
offset += len;
}
} else if (matchedInfo->paramCType == SQL_C_CHAR) {
std::string s = pyObj.cast<std::string>();
size_t totalBytes = s.size();
const char* dataPtr = s.data();
size_t offset = 0;
size_t chunkBytes = DAE_CHUNK_SIZE;
while (offset < totalBytes) {
size_t len = std::min(chunkBytes, totalBytes - offset);

rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), static_cast<SQLLEN>(len));
if (!SQL_SUCCEEDED(rc)) {
LOG("SQLPutData failed at offset {} of {}", offset, totalBytes);
return rc;
}
offset += len;
}
} else {
ThrowStdException("Unsupported C type for str in DAE");
}
} else {
ThrowStdException("DAE only supported for str or bytes");
Expand Down
154 changes: 154 additions & 0 deletions tests/test_004_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5124,6 +5124,160 @@ def test_emoji_round_trip(cursor, db_connection):
except Exception as e:
pytest.fail(f"Error for input {repr(text)}: {e}")

def test_varchar_max_insert_non_lob(cursor, db_connection):
    """Test small VARCHAR(MAX) inserts (non-LOB path).

    Covers a short string, an empty string, and NULL — all below the
    DAE/LOB threshold, so they should bind through the inline parameter
    path rather than data-at-execution streaming.
    """
    try:
        cursor.execute("CREATE TABLE #pytest_varchar_nonlob (col VARCHAR(MAX))")
        db_connection.commit()

        # Small, empty, and NULL values all take the non-LOB bind path.
        for value in ("Hello, world!", "", None):
            cursor.execute(
                "INSERT INTO #pytest_varchar_nonlob (col) VALUES (?)",
                [value],
            )
            db_connection.commit()

        # Fetch commented for now
        # cursor.execute("SELECT col FROM #pytest_varchar_nonlob")
        # rows = cursor.fetchall()
        # assert rows == [["Hello, world!"], [""], [None]]

    finally:
        # Best-effort cleanup so a rerun on the same session does not
        # collide with the existing temp table; never mask a test failure.
        try:
            cursor.execute("DROP TABLE IF EXISTS #pytest_varchar_nonlob")
            db_connection.commit()
        except Exception:
            pass


def test_varchar_max_insert_lob(cursor, db_connection):
    """Test large VARCHAR(MAX) insert (LOB path).

    Inserts a 100k-character string — well over the inline threshold —
    to exercise the data-at-execution (SQLPutData) streaming path.
    """
    try:
        cursor.execute("CREATE TABLE #pytest_varchar_lob (col VARCHAR(MAX))")
        db_connection.commit()

        large_str = "A" * 100_000  # > 8k to trigger the LOB/DAE path
        cursor.execute(
            "INSERT INTO #pytest_varchar_lob (col) VALUES (?)",
            [large_str],
        )
        db_connection.commit()

        # Fetch commented for now
        # cursor.execute("SELECT col FROM #pytest_varchar_lob")
        # rows = cursor.fetchall()
        # assert rows == [[large_str]]

    finally:
        # Best-effort cleanup; never mask the test's own failure.
        try:
            cursor.execute("DROP TABLE IF EXISTS #pytest_varchar_lob")
            db_connection.commit()
        except Exception:
            pass


def test_nvarchar_max_insert_non_lob(cursor, db_connection):
    """Test small NVARCHAR(MAX) inserts (non-LOB path).

    Covers a short Unicode string (with a non-ASCII code point), an
    empty string, and NULL — all below the DAE/LOB threshold.
    """
    try:
        cursor.execute("CREATE TABLE #pytest_nvarchar_nonlob (col NVARCHAR(MAX))")
        db_connection.commit()

        # Small Unicode, empty, and NULL values all take the non-LOB path.
        for value in ("Unicode ✨ test", "", None):
            cursor.execute(
                "INSERT INTO #pytest_nvarchar_nonlob (col) VALUES (?)",
                [value],
            )
            db_connection.commit()

        # Fetch commented for now
        # cursor.execute("SELECT col FROM #pytest_nvarchar_nonlob")
        # rows = cursor.fetchall()
        # assert rows == [["Unicode ✨ test"], [""], [None]]

    finally:
        # Best-effort cleanup so reruns on the same session don't collide.
        try:
            cursor.execute("DROP TABLE IF EXISTS #pytest_nvarchar_nonlob")
            db_connection.commit()
        except Exception:
            pass


def test_nvarchar_max_insert_lob(cursor, db_connection):
    """Test large NVARCHAR(MAX) insert (LOB path).

    Uses a non-BMP character (each emoji is a surrogate pair: 2 UTF-16
    code units, 4 bytes) so the DAE streaming path is exercised with
    surrogate pairs, not just BMP text.
    """
    try:
        cursor.execute("CREATE TABLE #pytest_nvarchar_lob (col NVARCHAR(MAX))")
        db_connection.commit()

        # 50_000 emojis -> 100_000 UTF-16 code units (200_000 bytes),
        # far past the inline threshold.
        large_str = "📝" * 50_000
        cursor.execute(
            "INSERT INTO #pytest_nvarchar_lob (col) VALUES (?)",
            [large_str],
        )
        db_connection.commit()

        # Fetch commented for now
        # cursor.execute("SELECT col FROM #pytest_nvarchar_lob")
        # rows = cursor.fetchall()
        # assert rows == [[large_str]]

    finally:
        # Best-effort cleanup; never mask the test's own failure.
        try:
            cursor.execute("DROP TABLE IF EXISTS #pytest_nvarchar_lob")
            db_connection.commit()
        except Exception:
            pass

def test_nvarchar_max_boundary(cursor, db_connection):
    """Test NVARCHAR(MAX) inserts at the LOB boundary sizes.

    Two payloads straddle the inline/DAE cutoff measured in UTF-16 code
    units: 4096 BMP chars (4096 units, 8k bytes) and 4096 emojis
    (8192 units via surrogate pairs, 16k bytes).
    """
    try:
        cursor.execute("DROP TABLE IF EXISTS #pytest_nvarchar_boundary")
        cursor.execute("CREATE TABLE #pytest_nvarchar_boundary (col NVARCHAR(MAX))")
        db_connection.commit()

        # 4k BMP chars = 4k UTF-16 code units (8k bytes).
        cursor.execute("INSERT INTO #pytest_nvarchar_boundary (col) VALUES (?)", ["A" * 4096])
        # 4k emojis = 8k UTF-16 code units (16k bytes) — surrogate pairs.
        cursor.execute("INSERT INTO #pytest_nvarchar_boundary (col) VALUES (?)", ["📝" * 4096])
        db_connection.commit()

        # Fetch commented for now
        # cursor.execute("SELECT col FROM #pytest_nvarchar_boundary")
        # rows = cursor.fetchall()
        # assert rows == [["A" * 4096], ["📝" * 4096]]
    finally:
        # Best-effort cleanup mirroring the defensive DROP at the top.
        try:
            cursor.execute("DROP TABLE IF EXISTS #pytest_nvarchar_boundary")
            db_connection.commit()
        except Exception:
            pass


def test_nvarchar_max_chunk_edge(cursor, db_connection):
    """Test NVARCHAR(MAX) insert slightly larger than one DAE chunk.

    Sized so the SQLPutData loop must issue a second, small chunk —
    catching off-by-one errors at the chunk boundary.
    """
    try:
        cursor.execute("DROP TABLE IF EXISTS #pytest_nvarchar_chunk")
        cursor.execute("CREATE TABLE #pytest_nvarchar_chunk (col NVARCHAR(MAX))")
        db_connection.commit()

        chunk_size = 8192  # bytes per DAE chunk
        # Each emoji is 4 UTF-16 bytes, so chunk_size // 4 emojis fill
        # exactly one chunk; +3 spills a few bytes into a second chunk.
        test_str = "📝" * ((chunk_size // 4) + 3)
        cursor.execute("INSERT INTO #pytest_nvarchar_chunk (col) VALUES (?)", [test_str])
        db_connection.commit()

        # Fetch commented for now
        # cursor.execute("SELECT col FROM #pytest_nvarchar_chunk")
        # row = cursor.fetchone()
        # assert row[0] == test_str
    finally:
        # Best-effort cleanup mirroring the defensive DROP at the top.
        try:
            cursor.execute("DROP TABLE IF EXISTS #pytest_nvarchar_chunk")
            db_connection.commit()
        except Exception:
            pass


def test_close(db_connection):
"""Test closing the cursor"""
Expand Down
Loading