77
88Supports both SQLAlchemy 1.x and 2.x versions.
99"""
10+ import logging
1011from typing import Any , Dict , List , Optional , Tuple , Type
1112
13+ from sqlalchemy import pool , types
14+ from sqlalchemy .engine import default , url
15+ from sqlalchemy .sql import compiler
16+ from sqlalchemy .sql .sqltypes import NULLTYPE
17+
18+ import pymongosql
19+
20+ _logger = logging .getLogger (__name__ )
21+
1222try :
1323 import sqlalchemy
1424
1828 SQLALCHEMY_VERSION = (1 , 4 ) # Default fallback
1929 SQLALCHEMY_2X = False
2030
21- from sqlalchemy import pool , types
22- from sqlalchemy .engine import default , url
23- from sqlalchemy .sql import compiler
24- from sqlalchemy .sql .sqltypes import NULLTYPE
25-
2631# Version-specific imports
2732if SQLALCHEMY_2X :
2833 try :
3338else :
3439 from sqlalchemy .engine .interfaces import Dialect
3540
36- import pymongosql
37-
3841
3942class PyMongoSQLIdentifierPreparer (compiler .IdentifierPreparer ):
4043 """MongoDB-specific identifier preparer.
@@ -274,34 +277,52 @@ def create_connect_args(self, url: url.URL) -> Tuple[List[Any], Dict[str, Any]]:
274277
def get_schema_names(self, connection, **kwargs):
    """Return the MongoDB database names, which play the role of SQL schemas.

    The names come from the ``listDatabases`` admin command issued on the
    ``_client`` attribute of ``connection.connection`` rather than from a
    SQL ``SHOW DATABASES`` statement.

    Args:
        connection: Object exposing the underlying DBAPI connection as
            ``.connection`` (presumably a SQLAlchemy connection - confirm
            against the caller).
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The list of database names, or ``["default"]`` when the DBAPI
        connection exposes no ``_client`` or the lookup raises.
    """
    fallback = ["default"]  # Fallback to default database
    try:
        # Reach the MongoDB client held by the DBAPI connection.
        raw_conn = connection.connection
        if not hasattr(raw_conn, "_client"):
            return fallback

        listing = raw_conn._client.admin.command("listDatabases")
        names = []
        for entry in listing.get("databases", []):
            names.append(entry["name"])
        return names
    except Exception as e:
        # Best-effort introspection: report the failure and use the fallback.
        _logger.warning(f"Failed to get database names: {e}")
    return fallback
280291
def has_table(self, connection, table_name: str, schema: Optional[str] = None, **kwargs) -> bool:
    """Report whether a collection (table) named ``table_name`` exists.

    Args:
        connection: Object exposing the underlying DBAPI connection as
            ``.connection``.
        table_name: Collection name to look for.
        schema: Database name to inspect; when falsy, the DBAPI
            connection's ``database`` attribute is used instead.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        ``True`` if ``table_name`` is among the names returned by
        ``list_collection_names()``. ``False`` if it is not, if the DBAPI
        connection exposes no ``_client``, or if the lookup raises.
    """
    try:
        raw_conn = connection.connection
        if not hasattr(raw_conn, "_client"):
            return False

        # An explicit schema selects that database on the client; otherwise
        # the connection's current database is inspected.
        target_db = raw_conn._client[schema] if schema else raw_conn.database
        return table_name in target_db.list_collection_names()
    except Exception as e:
        # Best-effort introspection: a failed lookup is reported as "absent".
        _logger.warning(f"Failed to check table existence: {e}")
    return False
293309
def get_table_names(self, connection, schema: Optional[str] = None, **kwargs) -> List[str]:
    """Return the collection (table) names of a database.

    Args:
        connection: Object exposing the underlying DBAPI connection as
            ``.connection``.
        schema: Database name to inspect; when falsy, the DBAPI
            connection's ``database`` attribute is used instead.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        Whatever ``list_collection_names()`` yields for the selected
        database, or ``[]`` when the DBAPI connection exposes no
        ``_client`` or the lookup raises.
    """
    try:
        raw_conn = connection.connection
        if not hasattr(raw_conn, "_client"):
            return []

        # An explicit schema selects that database on the client; otherwise
        # the connection's current database is inspected.
        target_db = raw_conn._client[schema] if schema else raw_conn.database
        return target_db.list_collection_names()
    except Exception as e:
        # Best-effort introspection: a failed lookup yields no tables.
        _logger.warning(f"Failed to get table names: {e}")
    return []
305326
306327 def get_columns (self , connection , table_name : str , schema : Optional [str ] = None , ** kwargs ) -> List [Dict [str , Any ]]:
307328 """Get column information for a collection.
@@ -310,23 +331,49 @@ def get_columns(self, connection, table_name: str, schema: Optional[str] = None,
310331 """
311332 columns = []
312333 try :
313- # Use DESCRIBE-like functionality if available
314- if schema :
315- sql = f"DESCRIBE { schema } .{ table_name } "
316- else :
317- sql = f"DESCRIBE { table_name } "
318-
319- cursor = connection .execute (sql )
320- for row in cursor .fetchall ():
321- # Assume row format: (name, type, nullable, default)
322- col_info = {
323- "name" : row [0 ],
324- "type" : self ._get_column_type (row [1 ] if len (row ) > 1 else "object" ),
325- "nullable" : row [2 ] if len (row ) > 2 else True ,
326- "default" : row [3 ] if len (row ) > 3 else None ,
327- }
328- columns .append (col_info )
329- except Exception :
334+ # Use direct MongoDB operations to sample documents and infer schema
335+ db_connection = connection .connection
336+ if hasattr (db_connection , "_client" ):
337+ if schema :
338+ db = db_connection ._client [schema ]
339+ else :
340+ db = db_connection .database
341+
342+ collection = db [table_name ]
343+
344+ # Sample a few documents to infer schema
345+ sample_docs = list (collection .find ().limit (10 ))
346+ if sample_docs :
347+ # Collect all unique field names and types
348+ field_types = {}
349+ for doc in sample_docs :
350+ for field_name , value in doc .items ():
351+ if field_name not in field_types :
352+ field_types [field_name ] = self ._infer_bson_type (value )
353+
354+ # Convert to SQLAlchemy column format
355+ for field_name , bson_type in field_types .items ():
356+ columns .append (
357+ {
358+ "name" : field_name ,
359+ "type" : self ._get_column_type (bson_type ),
360+ "nullable" : field_name != "_id" , # _id is always required
361+ "default" : None ,
362+ }
363+ )
364+ else :
365+ # Empty collection, provide minimal _id column
366+ columns = [
367+ {
368+ "name" : "_id" ,
369+ "type" : types .String (),
370+ "nullable" : False ,
371+ "default" : None ,
372+ }
373+ ]
374+
375+ except Exception as e :
376+ _logger .warning (f"Failed to get column info for { table_name } : { e } " )
330377 # Fallback: provide minimal _id column
331378 columns = [
332379 {
@@ -339,6 +386,33 @@ def get_columns(self, connection, table_name: str, schema: Optional[str] = None,
339386
340387 return columns
341388
389+ def _infer_bson_type (self , value : Any ) -> str :
390+ """Infer BSON type from a Python value."""
391+ from datetime import datetime
392+
393+ from bson import ObjectId
394+
395+ if isinstance (value , ObjectId ):
396+ return "objectId"
397+ elif isinstance (value , str ):
398+ return "string"
399+ elif isinstance (value , bool ):
400+ return "bool"
401+ elif isinstance (value , int ):
402+ return "int"
403+ elif isinstance (value , float ):
404+ return "double"
405+ elif isinstance (value , datetime ):
406+ return "date"
407+ elif isinstance (value , list ):
408+ return "array"
409+ elif isinstance (value , dict ):
410+ return "object"
411+ elif value is None :
412+ return "null"
413+ else :
414+ return "string" # Default fallback
415+
342416 def _get_column_type (self , mongo_type : str ) -> Type [types .TypeEngine ]:
343417 """Map MongoDB/BSON types to SQLAlchemy types."""
344418 type_map = {
@@ -377,22 +451,32 @@ def get_indexes(self, connection, table_name: str, schema: Optional[str] = None,
377451 """Get index information for a collection."""
378452 indexes = []
379453 try :
380- if schema :
381- sql = f"SHOW INDEXES FROM { schema } .{ table_name } "
382- else :
383- sql = f"SHOW INDEXES FROM { table_name } "
384-
385- cursor = connection .execute (sql )
386- for row in cursor .fetchall ():
387- # Assume row format: (name, column_names, unique)
388- index_info = {
389- "name" : row [0 ],
390- "column_names" : [row [1 ]] if isinstance (row [1 ], str ) else row [1 ],
391- "unique" : row [2 ] if len (row ) > 2 else False ,
392- }
393- indexes .append (index_info )
394- except Exception :
395- # Always include the default _id index
454+ # Use direct MongoDB operations to get indexes
455+ db_connection = connection .connection
456+ if hasattr (db_connection , "_client" ):
457+ if schema :
458+ db = db_connection ._client [schema ]
459+ else :
460+ db = db_connection .database
461+
462+ collection = db [table_name ]
463+
464+ # Get index information
465+ index_info = collection .index_information ()
466+ for index_name , index_spec in index_info .items ():
467+ # Extract column names from key specification
468+ column_names = [field [0 ] for field in index_spec .get ("key" , [])]
469+
470+ indexes .append (
471+ {
472+ "name" : index_name ,
473+ "column_names" : column_names ,
474+ "unique" : index_spec .get ("unique" , False ),
475+ }
476+ )
477+ except Exception as e :
478+ _logger .warning (f"Failed to get index info for { table_name } : { e } " )
479+ # Always include the default _id index as fallback
396480 indexes = [
397481 {
398482 "name" : "_id_" ,
@@ -431,6 +515,25 @@ def do_commit(self, dbapi_connection):
431515 # This is normal behavior for MongoDB connections
432516 pass
433517
def do_ping(self, dbapi_connection):
    """Ping the database to test connection status.

    Used by SQLAlchemy and tools like Superset for connection testing.
    This avoids the need to execute "SELECT 1" queries that would fail
    due to PartiQL grammar requirements.

    Args:
        dbapi_connection: The DBAPI connection to probe.

    Returns:
        The result of ``dbapi_connection.test_connection()`` when the
        connection provides that callable. Otherwise ``True`` if a ``ping``
        admin command issued on ``dbapi_connection._client`` succeeds, and
        ``False`` if there is no ``_client`` or the command raises.
    """
    tester = getattr(dbapi_connection, "test_connection", None)
    if callable(tester):
        return tester()

    # Fallback: issue the MongoDB ping command directly on the client.
    try:
        if hasattr(dbapi_connection, "_client"):
            dbapi_connection._client.admin.command("ping")
            return True
    except Exception as exc:
        # Report the failure like the other introspection methods do
        # instead of discarding it; the result is still "not alive".
        _logger.warning("Failed to ping database: %s", exc)
    return False
536+
434537
435538# Version information
436539__sqlalchemy_version__ = SQLALCHEMY_VERSION
0 commit comments