Skip to content

Commit 40a7f81

Browse files
author
Peng Ren
committed
Fix the type code issue
1 parent 573f976 commit 40a7f81

File tree

4 files changed

+174
-42
lines changed

4 files changed

+174
-42
lines changed

pymongosql/result_set.py

Lines changed: 125 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -430,22 +430,142 @@ class PreProcessedResultSet(ResultSet):
430430
"""Result set for pre-formatted rows that don't need projection processing.
431431
432432
Used when rows are already projected and formatted (e.g., from SQLite intermediate storage).
433-
Skips the document projection step and goes directly to formatting as tuples.
433+
Skips the document projection step but applies type conversions for projection function results.
434434
"""
435435

436436
def _process_and_cache_batch(self, batch: List[Dict[str, Any]]) -> None:
437437
"""Process and cache a batch of pre-processed documents.
438438
439-
Unlike the base ResultSet, this skips projection processing and directly formats results.
439+
Unlike the base ResultSet, this skips projection processing but still applies type conversions
440+
for values that came from projection functions (e.g., converting strings back to datetime objects).
440441
"""
441442
if not batch:
442443
return
443-
# Skip projection processing - rows are already in final form
444-
# Just format to output format (tuple)
445-
formatted_batch = [self._format_result(doc) for doc in batch]
444+
# Skip full projection processing - rows are already in final form
445+
# But apply type conversions for projection function results
446+
converted_batch = [self._convert_projection_types(doc) for doc in batch]
447+
# Format to output format (tuple)
448+
formatted_batch = [self._format_result(doc) for doc in converted_batch]
446449
self._cached_results.extend(formatted_batch)
447450
self._total_fetched += len(batch)
448451

452+
def _convert_projection_types(self, doc: Dict[str, Any]) -> Dict[str, Any]:
453+
"""Convert string values from SQLite back to proper types based on projection functions.
454+
455+
When projection functions produce datetime/date objects, SQLite stores them as strings.
456+
This method converts them back to proper Python types.
457+
"""
458+
if not self._execution_plan:
459+
return doc
460+
461+
from datetime import datetime, timezone
462+
463+
from .sql.projection_functions import ProjectionFunctionRegistry
464+
465+
converted = dict(doc)
466+
467+
# Get projection information from execution plan
468+
projection_functions = getattr(self._execution_plan, "projection_functions", {})
469+
column_aliases = getattr(self._execution_plan, "column_aliases", {})
470+
471+
if not projection_functions:
472+
return converted
473+
474+
_ = ProjectionFunctionRegistry()
475+
476+
# Iterate through projection functions to convert values
477+
for field_name, func_info in projection_functions.items():
478+
# The column name might be aliased, so check both the original name and the alias
479+
col_names_to_check = [field_name]
480+
if field_name in column_aliases:
481+
col_names_to_check.append(column_aliases[field_name])
482+
# Also check for the mongo_to_bracket_key format
483+
col_names_to_check.append(self._mongo_to_bracket_key(field_name))
484+
485+
col_name = None
486+
for check_name in col_names_to_check:
487+
if check_name in converted:
488+
col_name = check_name
489+
break
490+
491+
if not col_name:
492+
continue
493+
494+
value = converted[col_name]
495+
if value is None:
496+
continue
497+
498+
# Extract function name and format parameter
499+
func_name = None
500+
format_param = None
501+
502+
if isinstance(func_info, dict):
503+
func_name = func_info.get("name")
504+
format_param = func_info.get("format_param")
505+
elif isinstance(func_info, (list, tuple)):
506+
if len(func_info) >= 1:
507+
func_name = func_info[0]
508+
if len(func_info) >= 2:
509+
format_param = func_info[1]
510+
511+
if not func_name:
512+
continue
513+
514+
# Handle special cases for SQLite-stored projection function results
515+
if func_name.upper() in ("DATE", "DATETIME", "TIMESTAMP"):
516+
# These functions should produce datetime objects
517+
try:
518+
if isinstance(value, str):
519+
# Handle BSON Timestamp string representation: "Timestamp(timestamp_int, increment)"
520+
if value.startswith("Timestamp("):
521+
match = re.match(r"Timestamp\((\d+),\s*\d+\)", value)
522+
if match:
523+
timestamp_int = int(match.group(1))
524+
# Convert UNIX timestamp to datetime
525+
converted[col_name] = datetime.fromtimestamp(timestamp_int, tz=timezone.utc).replace(
526+
tzinfo=None
527+
)
528+
continue
529+
530+
# Try to parse as ISO format date/datetime string
531+
try:
532+
# Try full datetime first
533+
converted[col_name] = datetime.fromisoformat(value)
534+
except (ValueError, TypeError):
535+
try:
536+
# Try date only (YYYY-MM-DD)
537+
converted[col_name] = datetime.strptime(value, "%Y-%m-%d")
538+
except (ValueError, TypeError):
539+
try:
540+
# Try with custom format if provided
541+
if format_param:
542+
converted[col_name] = datetime.strptime(value, format_param)
543+
except (ValueError, TypeError):
544+
# Keep original value if all conversions fail
545+
pass
546+
except Exception as e:
547+
_logger.debug(f"Error converting {col_name} using {func_name}: {e}")
548+
# Keep original value if conversion fails
549+
elif func_name.upper() in ("NUMBER", "INT"):
550+
# These functions should produce numeric values
551+
try:
552+
if isinstance(value, str):
553+
# Try to convert to float
554+
converted[col_name] = float(value)
555+
except (ValueError, TypeError):
556+
# Keep original value if conversion fails
557+
pass
558+
elif func_name.upper() == "BOOL":
559+
# BOOL function should produce boolean values
560+
try:
561+
if isinstance(value, str):
562+
converted[col_name] = value.lower() in ("true", "1", "yes")
563+
except Exception:
564+
# Keep original value if conversion fails
565+
pass
566+
567+
return converted
568+
449569

450570
# For backward compatibility
451571
MongoResultSet = ResultSet

pymongosql/superset_mongodb/executor.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,16 @@ def execute(
161161
# This tells the cursor to use PreProcessedResultSet instead of regular ResultSet
162162
final_execution_plan.from_intermediate_storage = True
163163

164-
# Extract projection_output from MongoDB result set description
165-
# This provides the correct mapping from outer query column names to type codes
164+
# Extract projection_output and projection_functions from MongoDB execution plan
165+
# This provides the correct mapping from outer query column names to type codes and conversion functions
166+
if hasattr(mongo_execution_plan, "projection_output") and mongo_execution_plan.projection_output:
167+
final_execution_plan.projection_output = mongo_execution_plan.projection_output
168+
169+
# Also copy projection_functions from the inner MongoDB query so PreProcessedResultSet
170+
# can convert string values from SQLite back to proper types (datetime, etc.)
171+
if hasattr(mongo_execution_plan, "projection_functions") and mongo_execution_plan.projection_functions:
172+
final_execution_plan.projection_functions = mongo_execution_plan.projection_functions
173+
166174
if mongo_result_set.description:
167175
projection_output = []
168176
for col_name, type_code, *_ in mongo_result_set.description:
@@ -177,7 +185,9 @@ def execute(
177185

178186
projection_output.append({"output_name": col_name, "function": func_info})
179187

180-
final_execution_plan.projection_output = projection_output
188+
# Only update if projection_output wasn't already set from mongo_execution_plan
189+
if not hasattr(final_execution_plan, "projection_output") or not final_execution_plan.projection_output:
190+
final_execution_plan.projection_output = projection_output
181191

182192
self._execution_plan = final_execution_plan
183193

tests/test_projection_functions.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def test_convert_none(self):
8787
def test_get_type_code(self):
8888
"""Test type code for date function"""
8989
func = DateFunction()
90-
assert func.get_type_code() == "datetime"
90+
assert func.get_type_code() == datetime
9191

9292

9393
class TestDatetimeFunction:
@@ -167,7 +167,7 @@ def test_convert_none(self):
167167
def test_get_type_code(self):
168168
"""Test type code for datetime function"""
169169
func = DatetimeFunction()
170-
assert func.get_type_code() == "datetime"
170+
assert func.get_type_code() == datetime
171171

172172

173173
class TestTimestampFunction:
@@ -237,7 +237,7 @@ def test_convert_none(self):
237237
def test_get_type_code(self):
238238
"""Test type code for timestamp function"""
239239
func = TimestampFunction()
240-
assert func.get_type_code() == "datetime"
240+
assert func.get_type_code() == datetime
241241

242242

243243
class TestNumberFunction:
@@ -281,7 +281,7 @@ def test_convert_none(self):
281281
def test_get_type_code(self):
282282
"""Test type code for number function"""
283283
func = NumberFunction()
284-
assert func.get_type_code() == "float"
284+
assert func.get_type_code() == float
285285

286286

287287
class TestBoolFunction:
@@ -332,7 +332,7 @@ def test_convert_none(self):
332332
def test_get_type_code(self):
333333
"""Test type code for bool function"""
334334
func = BoolFunction()
335-
assert func.get_type_code() == "bool"
335+
assert func.get_type_code() == bool
336336

337337

338338
class TestSubstrFunction:
@@ -394,7 +394,7 @@ def test_convert_substr_non_string(self):
394394
def test_get_type_code(self):
395395
"""Test type code for substr function"""
396396
func = SubstrFunction()
397-
assert func.get_type_code() == "str"
397+
assert func.get_type_code() == str
398398

399399

400400
class TestReplaceFunction:
@@ -449,7 +449,7 @@ def test_convert_replace_non_string(self):
449449
def test_get_type_code(self):
450450
"""Test type code for replace function"""
451451
func = ReplaceFunction()
452-
assert func.get_type_code() == "str"
452+
assert func.get_type_code() == str
453453

454454

455455
class TestTrimFunction:
@@ -503,7 +503,7 @@ def test_convert_trim_non_string(self):
503503
def test_get_type_code(self):
504504
"""Test type code for trim function"""
505505
func = TrimFunction()
506-
assert func.get_type_code() == "str"
506+
assert func.get_type_code() == str
507507

508508

509509
class TestUpperFunction:
@@ -553,7 +553,7 @@ def test_convert_non_string(self):
553553
def test_get_type_code(self):
554554
"""Test type code for upper function"""
555555
func = UpperFunction()
556-
assert func.get_type_code() == "str"
556+
assert func.get_type_code() == str
557557

558558

559559
class TestLowerFunction:
@@ -603,7 +603,7 @@ def test_convert_non_string(self):
603603
def test_get_type_code(self):
604604
"""Test type code for lower function"""
605605
func = LowerFunction()
606-
assert func.get_type_code() == "str"
606+
assert func.get_type_code() == str
607607

608608

609609
class TestProjectionFunctionRegistry:

tests/test_superset_connection.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -425,14 +425,18 @@ def test_projection_functions_with_superset_execution(self, superset_conn):
425425
# All projection functions return string type codes like 'float', 'datetime', 'date', etc.
426426
import datetime
427427

428-
# NUMBER() should produce 'float' type code
429-
assert type_codes["numeric_age"] == "float", f"Expected 'float' but got {type_codes['numeric_age']!r}"
428+
# NUMBER() should produce float type code
429+
assert type_codes["numeric_age"] == float, f"Expected float but got {type_codes['numeric_age']!r}"
430430

431-
# DATE() should produce 'datetime' type code (note: projection functions return 'datetime' for both DATE and DATETIME)
432-
assert type_codes["creation_date"] == "datetime", f"Expected 'datetime' but got {type_codes['creation_date']!r}"
431+
# DATE() should produce datetime type code (note: projection functions return datetime for both DATE and DATETIME)
432+
assert (
433+
type_codes["creation_date"] == datetime.datetime
434+
), f"Expected datetime but got {type_codes['creation_date']!r}"
433435

434-
# DATETIME() should produce 'datetime' type code
435-
assert type_codes["last_updated"] == "datetime", f"Expected 'datetime' but got {type_codes['last_updated']!r}"
436+
# DATETIME() should produce datetime type code
437+
assert (
438+
type_codes["last_updated"] == datetime.datetime
439+
), f"Expected datetime but got {type_codes['last_updated']!r}"
436440

437441
# Verify data values are correctly converted
438442
for row in rows:
@@ -445,11 +449,11 @@ def test_projection_functions_with_superset_execution(self, superset_conn):
445449
row[numeric_age_idx], (int, float)
446450
), f"numeric_age should be numeric, got {type(row[numeric_age_idx])}"
447451

448-
# DATE() should convert to date object
452+
# DATE() should convert to datetime object
449453
creation_date_val = row[creation_date_idx]
450454
assert creation_date_val is None or isinstance(
451455
creation_date_val, datetime.date
452-
), f"creation_date should be date or None, got {type(creation_date_val)}"
456+
), f"creation_date should be datetime or None, got {type(creation_date_val)}"
453457

454458
# DATETIME() should convert to datetime object
455459
last_updated_val = row[last_updated_idx]
@@ -494,24 +498,22 @@ def test_projection_functions_with_custom_format(self, superset_conn):
494498

495499
import datetime
496500

497-
from bson import Timestamp
498-
499501
# Verify expected type codes
500-
# All projection functions return string type codes
501-
# DATE() should produce 'datetime' type code
502+
# All projection functions return type objects, not strings
503+
# DATE() should produce datetime type code
502504
assert (
503-
type_codes["formatted_date"] == "datetime"
504-
), f"Expected 'datetime' but got {type_codes['formatted_date']!r}"
505+
type_codes["formatted_date"] == datetime.datetime
506+
), f"Expected datetime but got {type_codes['formatted_date']!r}"
505507

506-
# DATETIME() should produce 'datetime' type code
508+
# DATETIME() should produce datetime type code
507509
assert (
508-
type_codes["formatted_datetime"] == "datetime"
509-
), f"Expected 'datetime' but got {type_codes['formatted_datetime']!r}"
510+
type_codes["formatted_datetime"] == datetime.datetime
511+
), f"Expected datetime but got {type_codes['formatted_datetime']!r}"
510512

511-
# TIMESTAMP() should produce 'datetime' type code
513+
# TIMESTAMP() should produce datetime type code
512514
assert (
513-
type_codes["timestamp_value"] == "datetime"
514-
), f"Expected 'datetime' but got {type_codes['timestamp_value']!r}"
515+
type_codes["timestamp_value"] == datetime.datetime
516+
), f"Expected datetime but got {type_codes['timestamp_value']!r}"
515517

516518
# Verify data values are correctly converted with custom formats
517519
for row in rows:
@@ -528,14 +530,14 @@ def test_projection_functions_with_custom_format(self, superset_conn):
528530
# DATETIME() with format should convert to datetime object
529531
datetime_value = row[formatted_datetime_idx]
530532
assert datetime_value is None or isinstance(
531-
datetime_value, datetime.datetime
533+
datetime_value, datetime.date
532534
), f"formatted_datetime should be datetime or None, got {type(datetime_value)}"
533535

534-
# TIMESTAMP() with format should convert to Timestamp object
536+
# TIMESTAMP() with format should convert to datetime object
535537
timestamp_value = row[timestamp_value_idx]
536538
assert timestamp_value is None or isinstance(
537-
timestamp_value, Timestamp
538-
), f"timestamp_value should be Timestamp or None, got {type(timestamp_value)}"
539+
timestamp_value, datetime.datetime
540+
), f"timestamp_value should be datetime or None, got {type(timestamp_value)}"
539541

540542
def test_empty_result_with_valid_description(self, superset_conn):
541543
"""Test that description is available for result sets, even if empty after filtering"""

0 commit comments

Comments
 (0)