Skip to content

Commit e16af86

Browse files
author
Peng Ren
committed
Add support for Superset
1 parent 91f206e commit e16af86

File tree

2 files changed

+497
-103
lines changed

2 files changed

+497
-103
lines changed

pymongosql/sqlalchemy_mongodb/sqlalchemy_dialect.py

Lines changed: 163 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,18 @@
77
88
Supports both SQLAlchemy 1.x and 2.x versions.
99
"""
10+
import logging
1011
from typing import Any, Dict, List, Optional, Tuple, Type
1112

13+
from sqlalchemy import pool, types
14+
from sqlalchemy.engine import default, url
15+
from sqlalchemy.sql import compiler
16+
from sqlalchemy.sql.sqltypes import NULLTYPE
17+
18+
import pymongosql
19+
20+
_logger = logging.getLogger(__name__)
21+
1222
try:
1323
import sqlalchemy
1424

@@ -18,11 +28,6 @@
1828
SQLALCHEMY_VERSION = (1, 4) # Default fallback
1929
SQLALCHEMY_2X = False
2030

21-
from sqlalchemy import pool, types
22-
from sqlalchemy.engine import default, url
23-
from sqlalchemy.sql import compiler
24-
from sqlalchemy.sql.sqltypes import NULLTYPE
25-
2631
# Version-specific imports
2732
if SQLALCHEMY_2X:
2833
try:
@@ -33,8 +38,6 @@
3338
else:
3439
from sqlalchemy.engine.interfaces import Dialect
3540

36-
import pymongosql
37-
3841

3942
class PyMongoSQLIdentifierPreparer(compiler.IdentifierPreparer):
4043
"""MongoDB-specific identifier preparer.
@@ -274,34 +277,52 @@ def create_connect_args(self, url: url.URL) -> Tuple[List[Any], Dict[str, Any]]:
274277

275278
def get_schema_names(self, connection, **kwargs):
    """Get list of databases (schemas in SQL terms).

    Issues the ``listDatabases`` admin command through the ``_client``
    attribute of ``connection.connection`` instead of running a SQL
    ``SHOW DATABASES`` statement.

    Args:
        connection: Object whose ``connection`` attribute is expected to
            expose the underlying client as ``_client``.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The database names reported by ``listDatabases``. Falls back to
        ``["default"]`` when the client is not reachable or the command
        fails, so callers always receive a list.
    """
    fallback = ["default"]
    try:
        db_connection = connection.connection
        if not hasattr(db_connection, "_client"):
            # No client to query: use the documented fallback
            # explicitly rather than relying on fall-through.
            return fallback
        result = db_connection._client.admin.command("listDatabases")
        return [db["name"] for db in result.get("databases", [])]
    except Exception as e:
        # Best-effort introspection: report the problem and degrade.
        _logger.warning("Failed to get database names: %s", e)
        return fallback
280291

281292
def has_table(self, connection, table_name: str, schema: Optional[str] = None, **kwargs) -> bool:
    """Check if a collection (table) exists.

    Looks the name up in ``list_collection_names()`` of the selected
    database rather than running a SQL ``SHOW COLLECTIONS`` statement.

    Args:
        connection: Object whose ``connection`` attribute is expected to
            expose ``_client`` and ``database``.
        table_name: Collection name to look for.
        schema: Database name; when falsy, ``connection.connection.database``
            is used instead of ``_client[schema]``.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        ``True`` if the collection name is listed, ``False`` otherwise,
        including when the client is unavailable or the lookup fails.
    """
    try:
        db_connection = connection.connection
        if not hasattr(db_connection, "_client"):
            # Always honour the declared ``bool`` return type; never
            # fall through and return ``None``.
            return False
        db = db_connection._client[schema] if schema else db_connection.database
        return table_name in db.list_collection_names()
    except Exception as e:
        # Best-effort introspection: report the problem and degrade.
        _logger.warning("Failed to check table existence: %s", e)
        return False
293309

294310
def get_table_names(self, connection, schema: Optional[str] = None, **kwargs) -> List[str]:
    """Get list of collections (tables).

    Reads the names via ``list_collection_names()`` on the selected
    database rather than running a SQL ``SHOW COLLECTIONS`` statement.

    Args:
        connection: Object whose ``connection`` attribute is expected to
            expose ``_client`` and ``database``.
        schema: Database name; when falsy, ``connection.connection.database``
            is used instead of ``_client[schema]``.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The collection names, or an empty list when the client is
        unavailable or the lookup fails.
    """
    try:
        db_connection = connection.connection
        if not hasattr(db_connection, "_client"):
            # Callers iterate the result, so return an empty list
            # rather than falling through to an implicit ``None``.
            return []
        db = db_connection._client[schema] if schema else db_connection.database
        return db.list_collection_names()
    except Exception as e:
        # Best-effort introspection: report the problem and degrade.
        _logger.warning("Failed to get table names: %s", e)
        return []
305326

306327
def get_columns(self, connection, table_name: str, schema: Optional[str] = None, **kwargs) -> List[Dict[str, Any]]:
307328
"""Get column information for a collection.
@@ -310,23 +331,49 @@ def get_columns(self, connection, table_name: str, schema: Optional[str] = None,
310331
"""
311332
columns = []
312333
try:
313-
# Use DESCRIBE-like functionality if available
314-
if schema:
315-
sql = f"DESCRIBE {schema}.{table_name}"
316-
else:
317-
sql = f"DESCRIBE {table_name}"
318-
319-
cursor = connection.execute(sql)
320-
for row in cursor.fetchall():
321-
# Assume row format: (name, type, nullable, default)
322-
col_info = {
323-
"name": row[0],
324-
"type": self._get_column_type(row[1] if len(row) > 1 else "object"),
325-
"nullable": row[2] if len(row) > 2 else True,
326-
"default": row[3] if len(row) > 3 else None,
327-
}
328-
columns.append(col_info)
329-
except Exception:
334+
# Use direct MongoDB operations to sample documents and infer schema
335+
db_connection = connection.connection
336+
if hasattr(db_connection, "_client"):
337+
if schema:
338+
db = db_connection._client[schema]
339+
else:
340+
db = db_connection.database
341+
342+
collection = db[table_name]
343+
344+
# Sample a few documents to infer schema
345+
sample_docs = list(collection.find().limit(10))
346+
if sample_docs:
347+
# Collect all unique field names and types
348+
field_types = {}
349+
for doc in sample_docs:
350+
for field_name, value in doc.items():
351+
if field_name not in field_types:
352+
field_types[field_name] = self._infer_bson_type(value)
353+
354+
# Convert to SQLAlchemy column format
355+
for field_name, bson_type in field_types.items():
356+
columns.append(
357+
{
358+
"name": field_name,
359+
"type": self._get_column_type(bson_type),
360+
"nullable": field_name != "_id", # _id is always required
361+
"default": None,
362+
}
363+
)
364+
else:
365+
# Empty collection, provide minimal _id column
366+
columns = [
367+
{
368+
"name": "_id",
369+
"type": types.String(),
370+
"nullable": False,
371+
"default": None,
372+
}
373+
]
374+
375+
except Exception as e:
376+
_logger.warning(f"Failed to get column info for {table_name}: {e}")
330377
# Fallback: provide minimal _id column
331378
columns = [
332379
{
@@ -339,6 +386,33 @@ def get_columns(self, connection, table_name: str, schema: Optional[str] = None,
339386

340387
return columns
341388

389+
def _infer_bson_type(self, value: Any) -> str:
    """Return the BSON type name that corresponds to a Python value.

    Unrecognised values are reported as ``"string"``; ``None`` is
    reported as ``"null"``.
    """
    from datetime import datetime

    from bson import ObjectId

    # Evaluated top to bottom; ``bool`` must come before ``int`` because
    # ``bool`` is a subclass of ``int``.
    ordered_checks = (
        (ObjectId, "objectId"),
        (str, "string"),
        (bool, "bool"),
        (int, "int"),
        (float, "double"),
        (datetime, "date"),
        (list, "array"),
        (dict, "object"),
    )
    for python_type, bson_name in ordered_checks:
        if isinstance(value, python_type):
            return bson_name

    # Nothing matched: distinguish an explicit null from an unknown type.
    return "null" if value is None else "string"
415+
342416
def _get_column_type(self, mongo_type: str) -> Type[types.TypeEngine]:
343417
"""Map MongoDB/BSON types to SQLAlchemy types."""
344418
type_map = {
@@ -377,22 +451,32 @@ def get_indexes(self, connection, table_name: str, schema: Optional[str] = None,
377451
"""Get index information for a collection."""
378452
indexes = []
379453
try:
380-
if schema:
381-
sql = f"SHOW INDEXES FROM {schema}.{table_name}"
382-
else:
383-
sql = f"SHOW INDEXES FROM {table_name}"
384-
385-
cursor = connection.execute(sql)
386-
for row in cursor.fetchall():
387-
# Assume row format: (name, column_names, unique)
388-
index_info = {
389-
"name": row[0],
390-
"column_names": [row[1]] if isinstance(row[1], str) else row[1],
391-
"unique": row[2] if len(row) > 2 else False,
392-
}
393-
indexes.append(index_info)
394-
except Exception:
395-
# Always include the default _id index
454+
# Use direct MongoDB operations to get indexes
455+
db_connection = connection.connection
456+
if hasattr(db_connection, "_client"):
457+
if schema:
458+
db = db_connection._client[schema]
459+
else:
460+
db = db_connection.database
461+
462+
collection = db[table_name]
463+
464+
# Get index information
465+
index_info = collection.index_information()
466+
for index_name, index_spec in index_info.items():
467+
# Extract column names from key specification
468+
column_names = [field[0] for field in index_spec.get("key", [])]
469+
470+
indexes.append(
471+
{
472+
"name": index_name,
473+
"column_names": column_names,
474+
"unique": index_spec.get("unique", False),
475+
}
476+
)
477+
except Exception as e:
478+
_logger.warning(f"Failed to get index info for {table_name}: {e}")
479+
# Always include the default _id index as fallback
396480
indexes = [
397481
{
398482
"name": "_id_",
@@ -431,6 +515,25 @@ def do_commit(self, dbapi_connection):
431515
# This is normal behavior for MongoDB connections
432516
pass
433517

518+
def do_ping(self, dbapi_connection):
    """Ping the database to test connection status.

    Used by SQLAlchemy and tools like Superset for connection testing.
    This avoids the need to execute "SELECT 1" queries that would fail
    due to PartiQL grammar requirements.

    Args:
        dbapi_connection: Connection object to probe. If it has a callable
            ``test_connection`` attribute, that is used; otherwise a
            ``ping`` admin command is sent through its ``_client``.

    Returns:
        The result of ``test_connection()`` when available; otherwise
        ``True`` if the ``ping`` command succeeded and ``False`` if it
        failed or no client is available.
    """
    probe = getattr(dbapi_connection, "test_connection", None)
    if callable(probe):
        return probe()

    # Fallback: send a ping command directly to the client.
    if not hasattr(dbapi_connection, "_client"):
        return False
    try:
        dbapi_connection._client.admin.command("ping")
    except Exception as exc:
        # Report the failure instead of swallowing it silently; the
        # caller still just sees an unsuccessful ping.
        _logger.warning("MongoDB ping failed: %s", exc)
        return False
    return True
536+
434537

435538
# Version information
436539
__sqlalchemy_version__ = SQLALCHEMY_VERSION

0 commit comments

Comments
 (0)