Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mssql_python/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ class ConstantsDDBC(Enum):
SQL_FETCH_ABSOLUTE = 5
SQL_FETCH_RELATIVE = 6
SQL_FETCH_BOOKMARK = 8
SQL_DATETIMEOFFSET = -155
SQL_C_SS_TIMESTAMPOFFSET = 0x4001
SQL_SCOPE_CURROW = 0
SQL_BEST_ROWID = 1
SQL_ROWVER = 2
Expand Down
25 changes: 18 additions & 7 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,13 +443,24 @@ def _map_sql_type(self, param, parameters_list, i, min_val=None, max_val=None):
)

if isinstance(param, datetime.datetime):
return (
ddbc_sql_const.SQL_TIMESTAMP.value,
ddbc_sql_const.SQL_C_TYPE_TIMESTAMP.value,
26,
6,
False,
)
if param.tzinfo is not None:
# Timezone-aware datetime -> DATETIMEOFFSET
return (
ddbc_sql_const.SQL_DATETIMEOFFSET.value,
ddbc_sql_const.SQL_C_SS_TIMESTAMPOFFSET.value,
34,
7,
False,
)
else:
# Naive datetime -> TIMESTAMP
return (
ddbc_sql_const.SQL_TIMESTAMP.value,
ddbc_sql_const.SQL_C_TYPE_TIMESTAMP.value,
26,
6,
False,
)

if isinstance(param, datetime.date):
return (
Expand Down
101 changes: 98 additions & 3 deletions mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
#include <iostream>
#include <utility> // std::forward
#include <filesystem>

//-------------------------------------------------------------------------------------------------
// Macro definitions
//-------------------------------------------------------------------------------------------------

// This constant is not exposed via sql.h, hence define it here
#define SQL_SS_TIME2 (-154)

#define SQL_SS_TIMESTAMPOFFSET (-155)
#define SQL_C_SS_TIMESTAMPOFFSET (0x4001)
#define MAX_DIGITS_IN_NUMERIC 64

#define STRINGIFY_FOR_CASE(x) \
Expand Down Expand Up @@ -94,6 +94,20 @@ struct ColumnBuffers {
indicators(numCols, std::vector<SQLLEN>(fetchSize)) {}
};

// Struct to hold the DateTimeOffset structure
struct DateTimeOffset
{
SQLSMALLINT year;
SQLUSMALLINT month;
SQLUSMALLINT day;
SQLUSMALLINT hour;
SQLUSMALLINT minute;
SQLUSMALLINT second;
SQLUINTEGER fraction; // Nanoseconds
SQLSMALLINT timezone_hour; // Offset hours from UTC
SQLSMALLINT timezone_minute; // Offset minutes from UTC
};

//-------------------------------------------------------------------------------------------------
// Function pointer initialization
//-------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -463,6 +477,39 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params,
dataPtr = static_cast<void*>(sqlTimePtr);
break;
}
case SQL_C_SS_TIMESTAMPOFFSET: {
py::object datetimeType = py::module_::import("datetime").attr("datetime");
if (!py::isinstance(param, datetimeType)) {
ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex));
}
// Checking if the object has a timezone
py::object tzinfo = param.attr("tzinfo");
if (tzinfo.is_none()) {
ThrowStdException("Datetime object must have tzinfo for SQL_C_SS_TIMESTAMPOFFSET at paramIndex " + std::to_string(paramIndex));
}

DateTimeOffset* dtoPtr = AllocateParamBuffer<DateTimeOffset>(paramBuffers);

dtoPtr->year = static_cast<SQLSMALLINT>(param.attr("year").cast<int>());
dtoPtr->month = static_cast<SQLUSMALLINT>(param.attr("month").cast<int>());
dtoPtr->day = static_cast<SQLUSMALLINT>(param.attr("day").cast<int>());
dtoPtr->hour = static_cast<SQLUSMALLINT>(param.attr("hour").cast<int>());
dtoPtr->minute = static_cast<SQLUSMALLINT>(param.attr("minute").cast<int>());
dtoPtr->second = static_cast<SQLUSMALLINT>(param.attr("second").cast<int>());
dtoPtr->fraction = static_cast<SQLUINTEGER>(param.attr("microsecond").cast<int>() * 1000);

py::object utcoffset = tzinfo.attr("utcoffset")(param);
int total_seconds = static_cast<int>(utcoffset.attr("total_seconds")().cast<double>());
std::div_t div_result = std::div(total_seconds, 3600);
dtoPtr->timezone_hour = static_cast<SQLSMALLINT>(div_result.quot);
dtoPtr->timezone_minute = static_cast<SQLSMALLINT>(div(div_result.rem, 60).quot);

dataPtr = static_cast<void*>(dtoPtr);
bufferLength = sizeof(DateTimeOffset);
strLenOrIndPtr = AllocateParamBuffer<SQLLEN>(paramBuffers);
*strLenOrIndPtr = bufferLength;
break;
}
case SQL_C_TYPE_TIMESTAMP: {
py::object datetimeType = py::module_::import("datetime").attr("datetime");
if (!py::isinstance(param, datetimeType)) {
Expand Down Expand Up @@ -514,7 +561,6 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params,
}
}
assert(SQLBindParameter_ptr && SQLGetStmtAttr_ptr && SQLSetDescField_ptr);

RETCODE rc = SQLBindParameter_ptr(
hStmt,
static_cast<SQLUSMALLINT>(paramIndex + 1), /* 1-based indexing */
Expand Down Expand Up @@ -2485,6 +2531,55 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
}
break;
}
case SQL_SS_TIMESTAMPOFFSET: {
DateTimeOffset dtoValue;
SQLLEN indicator;
ret = SQLGetData_ptr(
hStmt,
i, SQL_C_SS_TIMESTAMPOFFSET,
&dtoValue,
sizeof(dtoValue),
&indicator
);
if (SQL_SUCCEEDED(ret) && indicator != SQL_NULL_DATA) {
LOG("[Fetch] Retrieved DTO: {}-{}-{} {}:{}:{}, fraction(ns)={}, tz_hour={}, tz_minute={}",
dtoValue.year, dtoValue.month, dtoValue.day,
dtoValue.hour, dtoValue.minute, dtoValue.second,
dtoValue.fraction,
dtoValue.timezone_hour, dtoValue.timezone_minute
);

int totalMinutes = dtoValue.timezone_hour * 60 + dtoValue.timezone_minute;
// Validating offset
if (totalMinutes < -24 * 60 || totalMinutes > 24 * 60) {
std::ostringstream oss;
oss << "Invalid timezone offset from SQL_SS_TIMESTAMPOFFSET_STRUCT: "
<< totalMinutes << " minutes for column " << i;
ThrowStdException(oss.str());
}
// Convert fraction from ns to µs
int microseconds = dtoValue.fraction / 1000;
py::object datetime = py::module_::import("datetime");
py::object tzinfo = datetime.attr("timezone")(
datetime.attr("timedelta")(py::arg("minutes") = totalMinutes)
);
py::object py_dt = datetime.attr("datetime")(
dtoValue.year,
dtoValue.month,
dtoValue.day,
dtoValue.hour,
dtoValue.minute,
dtoValue.second,
microseconds,
tzinfo
);
row.append(py_dt);
} else {
LOG("Error fetching DATETIMEOFFSET for column {}, ret={}", i, ret);
row.append(py::none());
}
break;
}
case SQL_BINARY:
case SQL_VARBINARY:
case SQL_LONGVARBINARY: {
Expand Down
63 changes: 61 additions & 2 deletions tests/test_004_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"""

import pytest
from datetime import datetime, date, time
from datetime import datetime, date, time, timedelta, timezone
import time as time_module
import decimal
from contextlib import closing
Expand Down Expand Up @@ -6470,7 +6470,7 @@ def test_only_null_and_empty_binary(cursor, db_connection):
finally:
drop_table_if_exists(cursor, "#pytest_null_empty_binary")
db_connection.commit()

# ---------------------- VARCHAR(MAX) ----------------------

def test_varcharmax_short_fetch(cursor, db_connection):
Expand Down Expand Up @@ -7356,6 +7356,65 @@ def test_decimal_separator_calculations(cursor, db_connection):
cursor.execute("DROP TABLE IF EXISTS #pytest_decimal_calc_test")
db_connection.commit()

def test_datetimeoffset_read_write(cursor, db_connection):
"""
Test the driver's ability to correctly read and write DATETIMEOFFSET data,
including timezone information.
"""
try:
datetimeoffset_test_cases = [
(
"2023-10-26 10:30:00.0000000 +05:30",
datetime(2023, 10, 26, 10, 30, 0, 0,
tzinfo=timezone(timedelta(hours=5, minutes=30)))
),
(
"2023-10-27 15:45:10.1234567 -08:00",
datetime(2023, 10, 27, 15, 45, 10, 123456,
tzinfo=timezone(timedelta(hours=-8)))
),
(
"2023-10-28 20:00:05.9876543 +00:00",
datetime(2023, 10, 28, 20, 0, 5, 987654,
tzinfo=timezone(timedelta(hours=0)))
),
(
"invalid", # Placeholder for the SQL string
datetime(2023, 10, 29, 10, 0)
)
]
cursor.execute("IF OBJECT_ID('tempdb..#pytest_dto', 'U') IS NOT NULL DROP TABLE #pytest_dto;")
cursor.execute("CREATE TABLE #pytest_dto (id INT PRIMARY KEY, dto_column DATETIMEOFFSET);")
db_connection.commit()
insert_statement = "INSERT INTO #pytest_dto (id, dto_column) VALUES (?, ?);"
for i, (sql_str, python_dt) in enumerate(datetimeoffset_test_cases):
# Insert timezone-aware datetime objects
cursor.execute(insert_statement, i, python_dt)
db_connection.commit()

cursor.execute("SELECT id, dto_column FROM #pytest_dto ORDER BY id;")

for i, (sql_str, python_dt) in enumerate(datetimeoffset_test_cases):
if sql_str == "invalid":
continue

row = cursor.fetchone()
assert row is not None, f"No row fetched for test case {i}."

fetched_id, fetched_dto = row
assert fetched_dto.tzinfo is not None, "Fetched datetime object is naive."
expected_utc = python_dt.astimezone(timezone.utc).replace(tzinfo=None)
fetched_utc = fetched_dto.astimezone(timezone.utc).replace(tzinfo=None)
expected_utc = expected_utc.replace(microsecond=int(expected_utc.microsecond / 1000) * 1000)
fetched_utc = fetched_utc.replace(microsecond=int(fetched_utc.microsecond / 1000) * 1000)
assert fetched_utc == expected_utc, (
f"Value mismatch for test case {i}. "
f"Expected UTC: {expected_utc}, Got UTC: {fetched_utc}"
)
finally:
cursor.execute("IF OBJECT_ID('tempdb..#pytest_dto', 'U') IS NOT NULL DROP TABLE #pytest_dto;")
db_connection.commit()

def test_lowercase_attribute(cursor, db_connection):
"""Test that the lowercase attribute properly converts column names to lowercase"""

Expand Down