Skip to content

Commit 056c810

Browse files
committed
resolved review comments
1 parent aa9bfbc commit 056c810

File tree

4 files changed

+130
-50
lines changed

4 files changed

+130
-50
lines changed

โ€Žmssql_python/cursor.pyโ€Ž

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,14 @@ def _get_numeric_data(self, param):
216216
numeric_data.val = val
217217
return numeric_data
218218

219+
def _calculate_utf16_length(self, param: str) -> int:
220+
"""Return UTF-16 code unit length of a Python string."""
221+
try:
222+
return len(param.encode("utf-16-le")) // 2
223+
except UnicodeEncodeError as e:
224+
log('warning', "UTF-16 encoding failed for %r: %s. Falling back to len().", param, e)
225+
return len(param)
226+
219227
def _map_sql_type(self, param, parameters_list, i):
220228
"""
221229
Map a Python data type to the corresponding SQL type,
@@ -332,7 +340,7 @@ def _map_sql_type(self, param, parameters_list, i):
332340
# TODO: revisit
333341
if len(param) > 4000: # Long strings
334342
if is_unicode:
335-
utf16_len = len(param.encode("utf-16-le")) // 2
343+
utf16_len = self._calculate_utf16_length(param)
336344
return (
337345
ddbc_sql_const.SQL_WLONGVARCHAR.value,
338346
ddbc_sql_const.SQL_C_WCHAR.value,
@@ -346,7 +354,7 @@ def _map_sql_type(self, param, parameters_list, i):
346354
0,
347355
)
348356
if is_unicode: # Short Unicode strings
349-
utf16_len = len(param.encode("utf-16-le")) // 2
357+
utf16_len = self._calculate_utf16_length(param)
350358
return (
351359
ddbc_sql_const.SQL_WVARCHAR.value,
352360
ddbc_sql_const.SQL_C_WCHAR.value,

โ€Žmssql_python/pybind/ddbc_bindings.cppโ€Ž

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,10 +275,16 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params,
275275
AllocateParamBuffer<std::vector<SQLWCHAR>>(paramBuffers);
276276

277277
// Reserve space and convert from wstring to SQLWCHAR array
278-
sqlwcharBuffer->resize(strParam->size() + 1, 0); // +1 for null terminator
279278
std::vector<SQLWCHAR> utf16 = WStringToSQLWCHAR(*strParam);
280-
sqlwcharBuffer->assign(utf16.begin(), utf16.end());
281-
279+
if (utf16.size() < strParam->size()) {
280+
LOG("Warning: UTF-16 encoding shrank string? input={} output={}",
281+
strParam->size(), utf16.size());
282+
}
283+
if (utf16.size() > strParam->size() * 2 + 1) {
284+
LOG("Warning: UTF-16 expansion unusually large: input={} output={}",
285+
strParam->size(), utf16.size());
286+
}
287+
*sqlwcharBuffer = std::move(utf16);
282288
// Use the SQLWCHAR buffer instead of the wstring directly
283289
dataPtr = sqlwcharBuffer->data();
284290
bufferLength = sqlwcharBuffer->size() * sizeof(SQLWCHAR);
@@ -1704,6 +1710,12 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
17041710
// SQLGetData will null-terminate the data
17051711
#if defined(__APPLE__) || defined(__linux__)
17061712
auto raw_bytes = reinterpret_cast<const char*>(dataBuffer.data());
1713+
size_t actualBufferSize = dataBuffer.size() * sizeof(SQLWCHAR);
1714+
if (dataLen < 0 || static_cast<size_t>(dataLen) > actualBufferSize) {
1715+
LOG("Error: py::bytes creation request exceeds buffer size. dataLen={} buffer={}",
1716+
dataLen, actualBufferSize);
1717+
ThrowStdException("Invalid buffer length for py::bytes");
1718+
}
17071719
py::bytes py_bytes(raw_bytes, dataLen);
17081720
py::str decoded = py_bytes.attr("decode")("utf-16-le");
17091721
row.append(decoded);

โ€Žmssql_python/pybind/ddbc_bindings.hโ€Ž

Lines changed: 89 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -33,49 +33,107 @@ using namespace pybind11::literals;
3333
#include <sqlext.h>
3434

3535
#if defined(__APPLE__) || defined(__linux__)
36-
// macOS-specific headers
37-
#include <dlfcn.h>
36+
#include <dlfcn.h>
37+
38+
// Unicode constants for surrogate ranges and max scalar value
39+
constexpr uint32_t UNICODE_SURROGATE_HIGH_START = 0xD800;
40+
constexpr uint32_t UNICODE_SURROGATE_HIGH_END = 0xDBFF;
41+
constexpr uint32_t UNICODE_SURROGATE_LOW_START = 0xDC00;
42+
constexpr uint32_t UNICODE_SURROGATE_LOW_END = 0xDFFF;
43+
constexpr uint32_t UNICODE_MAX_CODEPOINT = 0x10FFFF;
44+
constexpr uint32_t UNICODE_REPLACEMENT_CHAR = 0xFFFD;
45+
46+
// Validate whether a code point is a legal Unicode scalar value
47+
// (excludes surrogate halves and values beyond U+10FFFF)
48+
inline bool IsValidUnicodeScalar(uint32_t cp) {
49+
return cp <= UNICODE_MAX_CODEPOINT &&
50+
!(cp >= UNICODE_SURROGATE_HIGH_START && cp <= UNICODE_SURROGATE_LOW_END);
51+
}
3852

39-
inline std::wstring SQLWCHARToWString(const SQLWCHAR* sqlwStr, size_t length = SQL_NTS) {
40-
if (!sqlwStr) return std::wstring();
53+
inline std::wstring SQLWCHARToWString(const SQLWCHAR* sqlwStr, size_t length = SQL_NTS) {
54+
if (!sqlwStr) return std::wstring();
4155

42-
if (length == SQL_NTS) {
43-
size_t i = 0;
44-
while (sqlwStr[i] != 0) ++i;
45-
length = i;
46-
}
56+
if (length == SQL_NTS) {
57+
size_t i = 0;
58+
while (sqlwStr[i] != 0) ++i;
59+
length = i;
60+
}
61+
std::wstring result;
62+
result.reserve(length);
4763

48-
std::wstring result;
49-
result.reserve(length);
64+
if constexpr (sizeof(SQLWCHAR) == 2) {
65+
// Decode UTF-16 to UTF-32 (with surrogate pair handling)
66+
for (size_t i = 0; i < length; ++i) {
67+
uint16_t wc = static_cast<uint16_t>(sqlwStr[i]);
68+
// Check if this is a high surrogate (U+D800โ€“U+DBFF)
69+
if (wc >= UNICODE_SURROGATE_HIGH_START && wc <= UNICODE_SURROGATE_HIGH_END && i + 1 < length) {
70+
uint16_t low = static_cast<uint16_t>(sqlwStr[i + 1]);
71+
// Check if the next code unit is a low surrogate (U+DC00โ€“U+DFFF)
72+
if (low >= UNICODE_SURROGATE_LOW_START && low <= UNICODE_SURROGATE_LOW_END) {
73+
// Combine surrogate pair into a single code point
74+
uint32_t cp = (((wc - UNICODE_SURROGATE_HIGH_START) << 10) | (low - UNICODE_SURROGATE_LOW_START)) + 0x10000;
75+
result.push_back(static_cast<wchar_t>(cp));
76+
++i; // Skip the low surrogate
77+
continue;
78+
}
79+
}
80+
// If valid scalar then append, else append replacement char (U+FFFD)
81+
if (IsValidUnicodeScalar(wc)) {
82+
result.push_back(static_cast<wchar_t>(wc));
83+
} else {
84+
result.push_back(static_cast<wchar_t>(UNICODE_REPLACEMENT_CHAR));
85+
}
86+
}
87+
} else {
88+
// SQLWCHAR is UTF-32, so just copy with validation
5089
for (size_t i = 0; i < length; ++i) {
51-
result.push_back(static_cast<wchar_t>(sqlwStr[i]));
90+
uint32_t cp = static_cast<uint32_t>(sqlwStr[i]);
91+
if (IsValidUnicodeScalar(cp)) {
92+
result.push_back(static_cast<wchar_t>(cp));
93+
} else {
94+
result.push_back(static_cast<wchar_t>(UNICODE_REPLACEMENT_CHAR));
95+
}
5296
}
53-
return result;
5497
}
98+
return result;
99+
}
55100

56-
inline std::vector<SQLWCHAR> WStringToSQLWCHAR(const std::wstring& str) {
101+
inline std::vector<SQLWCHAR> WStringToSQLWCHAR(const std::wstring& str) {
57102
std::vector<SQLWCHAR> result;
58-
59-
for (wchar_t wc : str) {
60-
uint32_t codePoint = static_cast<uint32_t>(wc);
61-
if (codePoint >= 0xD800 && codePoint <= 0xDFFF) {
62-
// Skip invalid lone surrogates (shouldn't occur in well-formed wchar_t strings)
63-
continue;
64-
} else if (codePoint <= 0xFFFF) {
65-
result.push_back(static_cast<SQLWCHAR>(codePoint));
66-
} else if (codePoint <= 0x10FFFF) {
67-
// Encode as surrogate pair
68-
codePoint -= 0x10000;
69-
SQLWCHAR highSurrogate = static_cast<SQLWCHAR>((codePoint >> 10) + 0xD800);
70-
SQLWCHAR lowSurrogate = static_cast<SQLWCHAR>((codePoint & 0x3FF) + 0xDC00);
71-
result.push_back(highSurrogate);
72-
result.push_back(lowSurrogate);
103+
result.reserve(str.size() + 2);
104+
if constexpr (sizeof(SQLWCHAR) == 2) {
105+
// Encode UTF-32 to UTF-16
106+
for (wchar_t wc : str) {
107+
uint32_t cp = static_cast<uint32_t>(wc);
108+
if (!IsValidUnicodeScalar(cp)) {
109+
cp = UNICODE_REPLACEMENT_CHAR;
110+
}
111+
if (cp <= 0xFFFF) {
112+
// Fits in a single UTF-16 code unit
113+
result.push_back(static_cast<SQLWCHAR>(cp));
114+
} else {
115+
// Encode as surrogate pair
116+
cp -= 0x10000;
117+
SQLWCHAR high = static_cast<SQLWCHAR>((cp >> 10) + UNICODE_SURROGATE_HIGH_START);
118+
SQLWCHAR low = static_cast<SQLWCHAR>((cp & 0x3FF) + UNICODE_SURROGATE_LOW_START);
119+
result.push_back(high);
120+
result.push_back(low);
121+
}
122+
}
123+
} else {
124+
// Encode UTF-32 directly
125+
for (wchar_t wc : str) {
126+
uint32_t cp = static_cast<uint32_t>(wc);
127+
if (IsValidUnicodeScalar(cp)) {
128+
result.push_back(static_cast<SQLWCHAR>(cp));
129+
} else {
130+
result.push_back(static_cast<SQLWCHAR>(UNICODE_REPLACEMENT_CHAR));
131+
}
73132
}
74133
}
75-
result.push_back(0); // Null terminator
134+
result.push_back(0); // null terminator
76135
return result;
77136
}
78-
79137
#endif
80138

81139
#if defined(__APPLE__) || defined(__linux__)

โ€Žtests/test_004_cursor.pyโ€Ž

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,22 +1315,24 @@ def test_row_column_mapping(cursor, db_connection):
13151315
cursor.execute("DROP TABLE #pytest_row_test")
13161316
db_connection.commit()
13171317

1318-
test_inputs = [
1319-
"Hello ๐Ÿ˜„",
1320-
"Flags ๐Ÿ‡ฎ๐Ÿ‡ณ๐Ÿ‡บ๐Ÿ‡ธ",
1321-
"Family ๐Ÿ‘จโ€๐Ÿ‘ฉโ€๐Ÿ‘งโ€๐Ÿ‘ฆ",
1322-
"Skin tone ๐Ÿ‘๐Ÿฝ",
1323-
"Brain ๐Ÿง ",
1324-
"Ice ๐ŸงŠ",
1325-
"Melting face ๐Ÿซ ",
1326-
"Accented รฉรผรฑรง",
1327-
"Chinese: ไธญๆ–‡",
1328-
"Japanese: ๆ—ฅๆœฌ่ชž",
1329-
]
1330-
13311318
def test_emoji_round_trip(cursor, db_connection):
13321319
"""Test round-trip of emoji and special characters"""
1333-
1320+
test_inputs = [
1321+
"Hello ๐Ÿ˜„",
1322+
"Flags ๐Ÿ‡ฎ๐Ÿ‡ณ๐Ÿ‡บ๐Ÿ‡ธ",
1323+
"Family ๐Ÿ‘จโ€๐Ÿ‘ฉโ€๐Ÿ‘งโ€๐Ÿ‘ฆ",
1324+
"Skin tone ๐Ÿ‘๐Ÿฝ",
1325+
"Brain ๐Ÿง ",
1326+
"Ice ๐ŸงŠ",
1327+
"Melting face ๐Ÿซ ",
1328+
"Accented รฉรผรฑรง",
1329+
"Chinese: ไธญๆ–‡",
1330+
"Japanese: ๆ—ฅๆœฌ่ชž",
1331+
"Hello ๐Ÿš€ World",
1332+
"admin๐Ÿ”’user",
1333+
"1๐Ÿš€' OR '1'='1",
1334+
]
1335+
13341336
cursor.execute("""
13351337
CREATE TABLE #pytest_emoji_test (
13361338
id INT IDENTITY PRIMARY KEY,

0 commit comments

Comments
ย (0)