Skip to content

Commit 1c78d74

Browse files
committed
FEAT: Access Token Login
1 parent 00a7c35 commit 1c78d74

File tree

4 files changed

+86
-53
lines changed

4 files changed

+86
-53
lines changed

mssql_python/connection.py

Lines changed: 67 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class Connection:
3333
close() -> None:
3434
"""
3535

36-
def __init__(self, connection_str: str, autocommit: bool = False, attrs_before: dict = {}, **kwargs) -> None:
36+
def __init__(self, connection_str: str, autocommit: bool = False, attrs_before: dict = None, **kwargs) -> None:
3737
"""
3838
Initialize the connection object with the specified connection string and parameters.
3939
@@ -78,29 +78,23 @@ def _construct_connection_string(self, connection_str: str, **kwargs) -> str:
7878
# Add the driver attribute to the connection string
7979
conn_str = add_driver_to_connection_str(connection_str)
8080

81-
# Check if access token authentication is being used
82-
if "attrs_before" in kwargs:
83-
# Skip adding Uid and Pwd for access token authentication
84-
if ENABLE_LOGGING:
85-
logger.info("Using access token authentication. Skipping Uid and Pwd.")
86-
else:
87-
# Add additional key-value pairs to the connection string
88-
for key, value in kwargs.items():
89-
if key.lower() == "host":
90-
key = "Server"
91-
elif key.lower() == "user":
92-
key = "Uid"
93-
elif key.lower() == "password":
94-
key = "Pwd"
95-
elif key.lower() == "database":
96-
key = "Database"
97-
elif key.lower() == "encrypt":
98-
key = "Encrypt"
99-
elif key.lower() == "trust_server_certificate":
100-
key = "TrustServerCertificate"
101-
else:
102-
continue
103-
conn_str += f"{key}={value};"
81+
# Add additional key-value pairs to the connection string
82+
for key, value in kwargs.items():
83+
if key.lower() == "host":
84+
key = "Server"
85+
elif key.lower() == "user":
86+
key = "Uid"
87+
elif key.lower() == "password":
88+
key = "Pwd"
89+
elif key.lower() == "database":
90+
key = "Database"
91+
elif key.lower() == "encrypt":
92+
key = "Encrypt"
93+
elif key.lower() == "trust_server_certificate":
94+
key = "TrustServerCertificate"
95+
else:
96+
continue
97+
conn_str += f"{key}={value};"
10498

10599
if ENABLE_LOGGING:
106100
logger.info("Final connection string: %s", conn_str)
@@ -136,29 +130,68 @@ def _initializer(self) -> None:
136130
)
137131
self._connect_to_db()
138132

139-
140133
def _apply_attrs_before(self):
141134
"""
142-
Apply a dictionary of attributes to the database connection before connecting.
135+
Apply specific pre-connection attributes.
136+
Currently, this method only processes an attribute with key 1256 (e.g., SQL_COPT_SS_ACCESS_TOKEN)
137+
if present in `self._attrs_before`. Other attributes are ignored.
143138
144139
Returns:
145-
bool: True if all attributes were successfully applied, False otherwise.
140+
bool: True.
146141
"""
147-
strencoding = "utf-16le"
148142

149143
if ENABLE_LOGGING:
150-
logger.info("Applying attrs_before: %s", self._attrs_before)
144+
logger.info("Attempting to apply pre-connection attributes (attrs_before): %s", self._attrs_before)
145+
146+
if not isinstance(self._attrs_before, dict):
147+
if self._attrs_before is not None and ENABLE_LOGGING:
148+
logger.warning(
149+
f"_attrs_before is of type {type(self._attrs_before).__name__}, "
150+
f"expected dict. Skipping attribute application."
151+
)
152+
elif self._attrs_before is None and ENABLE_LOGGING:
153+
logger.debug("_attrs_before is None. No pre-connection attributes to apply.")
154+
return True # Exit if _attrs_before is not a dictionary or is None
151155

152156
for key, value in self._attrs_before.items():
157+
ikey = None
153158
if isinstance(key, int):
154159
ikey = key
155160
elif isinstance(key, str) and key.isdigit():
156-
ikey = int(key)
161+
try:
162+
ikey = int(key)
163+
except ValueError:
164+
if ENABLE_LOGGING:
165+
logger.debug(
166+
f"Skipping attribute with key '{key}' in attrs_before: "
167+
f"could not convert string to int."
168+
)
169+
continue # Skip if string key is not a valid integer
157170
else:
158-
raise TypeError(f"Unsupported key type: {type(key).__name__}")
159-
160-
self._set_connection_attributes(ikey, value)
161-
171+
if ENABLE_LOGGING:
172+
logger.debug(
173+
f"Skipping attribute with key '{key}' in attrs_before due to "
174+
f"unsupported key type: {type(key).__name__}. Expected int or string representation of an int."
175+
)
176+
continue # Skip keys that are not int or string representation of an int
177+
178+
if ikey == ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value:
179+
if ENABLE_LOGGING:
180+
logger.info(
181+
f"Found attribute {ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value}. Attempting to set it."
182+
)
183+
self._set_connection_attributes(ikey, value)
184+
if ENABLE_LOGGING:
185+
logger.info(
186+
f"Call to set attribute {ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value} with value '{value}' completed."
187+
)
188+
# If you expect only one such key, you could add 'break' here.
189+
else:
190+
if ENABLE_LOGGING:
191+
logger.debug(
192+
f"Ignoring attribute with key '{key}' (resolved to {ikey}) in attrs_before "
193+
f"as it is not the target attribute ({ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value})."
194+
)
162195
return True
163196

164197
def _allocate_environment_handle(self):

mssql_python/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,4 +117,4 @@ class ConstantsDDBC(Enum):
117117
SQL_NULLABLE = 1
118118
SQL_MAX_NUMERIC_LEN = 16
119119
SQL_IS_POINTER = -4
120-
120+
SQL_COPT_SS_ACCESS_TOKEN = 1256

mssql_python/db_connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from mssql_python.connection import Connection
77

88

9-
def connect(connection_str: str, autocommit: bool = True, attrs_before: dict = {}, **kwargs) -> Connection:
9+
def connect(connection_str: str, autocommit: bool = True, attrs_before: dict = None, **kwargs) -> Connection:
1010
"""
1111
Constructor for creating a connection to the database.
1212

mssql_python/pybind/ddbc_bindings.cpp

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -710,30 +710,30 @@ SQLRETURN SQLSetConnectAttr_wrap(SqlHandlePtr ConnectionHandle, SQLINTEGER Attri
710710
int intValue = ValuePtr.cast<int>();
711711
value = reinterpret_cast<SQLPOINTER>(intValue);
712712
length = SQL_IS_INTEGER; // Integer values don't require a length
713-
} else if (py::isinstance<py::str>(ValuePtr)) {
714-
// Handle Unicode string values
715-
static std::wstring unicodeValueBuffer;
716-
unicodeValueBuffer = ValuePtr.cast<std::wstring>();
717-
value = const_cast<SQLWCHAR*>(unicodeValueBuffer.c_str());
718-
length = SQL_NTS; // Indicates null-terminated string
713+
// } else if (py::isinstance<py::str>(ValuePtr)) {
714+
// // Handle Unicode string values
715+
// static std::wstring unicodeValueBuffer;
716+
// unicodeValueBuffer = ValuePtr.cast<std::wstring>();
717+
// value = const_cast<SQLWCHAR*>(unicodeValueBuffer.c_str());
718+
// length = SQL_NTS; // Indicates null-terminated string
719719
} else if (py::isinstance<py::bytes>(ValuePtr) || py::isinstance<py::bytearray>(ValuePtr)) {
720720
// Handle byte or bytearray values (like access tokens)
721721
// Store in static buffer to ensure memory remains valid during connection
722722
static std::vector<std::string> bytesBuffers;
723723
bytesBuffers.push_back(ValuePtr.cast<std::string>());
724724
value = const_cast<char*>(bytesBuffers.back().c_str());
725725
length = SQL_IS_POINTER; // Indicates we're passing a pointer (required for token)
726-
} else if (py::isinstance<py::list>(ValuePtr) || py::isinstance<py::tuple>(ValuePtr)) {
727-
// Handle list or tuple values
728-
LOG("ValuePtr is a sequence (list or tuple)");
729-
for (py::handle item : ValuePtr) {
730-
LOG("Processing item in sequence");
731-
SQLRETURN ret = SQLSetConnectAttr_wrap(ConnectionHandle, Attribute, py::reinterpret_borrow<py::object>(item));
732-
if (!SQL_SUCCEEDED(ret)) {
733-
LOG("Failed to set attribute for item in sequence");
734-
return ret;
735-
}
736-
}
726+
// } else if (py::isinstance<py::list>(ValuePtr) || py::isinstance<py::tuple>(ValuePtr)) {
727+
// // Handle list or tuple values
728+
// LOG("ValuePtr is a sequence (list or tuple)");
729+
// for (py::handle item : ValuePtr) {
730+
// LOG("Processing item in sequence");
731+
// SQLRETURN ret = SQLSetConnectAttr_wrap(ConnectionHandle, Attribute, py::reinterpret_borrow<py::object>(item));
732+
// if (!SQL_SUCCEEDED(ret)) {
733+
// LOG("Failed to set attribute for item in sequence");
734+
// return ret;
735+
// }
736+
// }
737737
} else {
738738
LOG("Unsupported ValuePtr type");
739739
return SQL_ERROR;

0 commit comments

Comments
 (0)