Skip to content

Commit

Permalink
fix(python): Make the SQLAlchemy connection check more robust
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Oct 17, 2024
1 parent 3255066 commit 7f14e4f
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions py-polars/polars/io/database/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def _inject_type_overrides(

@staticmethod
def _is_alchemy_async(conn: Any) -> bool:
"""Check if the cursor/connection/session object is async."""
"""Check if the given connection is SQLALchemy async."""
try:
from sqlalchemy.ext.asyncio import (
AsyncConnection,
Expand All @@ -352,7 +352,7 @@ def _is_alchemy_async(conn: Any) -> bool:

@staticmethod
def _is_alchemy_engine(conn: Any) -> bool:
"""Check if the cursor/connection/session object is async."""
"""Check if the given connection is a SQLAlchemy Engine."""
from sqlalchemy.engine import Engine

if isinstance(conn, Engine):
Expand All @@ -364,9 +364,14 @@ def _is_alchemy_engine(conn: Any) -> bool:
except ImportError:
return False

@staticmethod
def _is_alchemy_object(conn: Any) -> bool:
"""Check if the given connection is a SQLAlchemy object (of any kind)."""
return type(conn).__module__.split(".", 1)[0] == "sqlalchemy"

@staticmethod
def _is_alchemy_session(conn: Any) -> bool:
"""Check if the cursor/connection/session object is async."""
"""Check if the given connection is a SQLAlchemy Session object."""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session, sessionmaker

Expand Down Expand Up @@ -482,7 +487,7 @@ def execute(

options = options or {}

if self.driver_name == "sqlalchemy":
if self._is_alchemy_object(self.cursor):
cursor_execute, options, query = self._sqlalchemy_setup(query, options)
else:
cursor_execute = self.cursor.execute
Expand Down

0 comments on commit 7f14e4f

Please sign in to comment.