Skip to content

Commit 73b1acc

Browse files
author
Peng Ren
committed
Fix aggregate() parsing when the collection name is wrapped in double quotes
1 parent 8903fde commit 73b1acc

File tree

3 files changed

+46
-3
lines changed

3 files changed

+46
-3
lines changed

pymongosql/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
if TYPE_CHECKING:
77
from .connection import Connection
88

9-
__version__: str = "0.4.2"
9+
__version__: str = "0.4.3"
1010

1111
# Globals https://www.python.org/dev/peps/pep-0249/#globals
1212
apilevel: str = "2.0"

pymongosql/sql/query_handler.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,18 @@ def can_handle(self, ctx: Any) -> bool:
168168
"""Check if this is a from context"""
169169
return hasattr(ctx, "tableReference")
170170

171+
@staticmethod
172+
def _strip_collection_quotes(name: str) -> str:
173+
"""Strip surrounding double quotes from collection name if present.
174+
175+
Args:
176+
name: Collection name, potentially quoted
177+
178+
Returns:
179+
Collection name with quotes removed
180+
"""
181+
return re.sub(r'^"([^"]+)"$', r"\1", name)
182+
171183
def _parse_function_call(self, ctx: Any) -> Optional[Dict[str, Any]]:
172184
"""
173185
Detect and parse aggregate() function calls in FROM clause.
@@ -196,13 +208,17 @@ def _parse_function_call(self, ctx: Any) -> Optional[Dict[str, Any]]:
196208

197209
# Pattern: [qualifier.]functionName(arg1, arg2)
198210
# We need to match: (optional_collection.)aggregate('...', '...')
199-
pattern = r"^(?:(\w+)\.)?aggregate\s*\(\s*'([^']*)'\s*,\s*'([^']*)'\s*\)$"
211+
# Support double-quoted collection names, which allow special characters such as hyphens
212+
pattern = r"^(?:(\"[^\"]+\"|\w+)\.)?aggregate\s*\(\s*'([^']*)'\s*,\s*'([^']*)'\s*\)$"
200213
match = re.match(pattern, text, re.IGNORECASE | re.DOTALL)
201214

202215
if not match:
203216
return None
204217

205218
collection = match.group(1) # Can be None for unqualified aggregate()
219+
# Strip quotes from collection name if present
220+
if collection:
221+
collection = self._strip_collection_quotes(collection)
206222
pipeline = match.group(2)
207223
options = match.group(3)
208224

@@ -245,7 +261,7 @@ def handle_visitor(self, ctx: PartiQLParser.FromClauseContext, parse_result: "Qu
245261
# Regular collection reference
246262
table_text = ctx.tableReference().getText()
247263
# Strip surrounding quotes from collection name (e.g., "user.accounts" -> user.accounts)
248-
collection_name = re.sub(r'^"([^"]+)"$', r"\1", table_text)
264+
collection_name = self._strip_collection_quotes(table_text)
249265
parse_result.collection = collection_name
250266
_logger.debug(f"Parsed regular collection: {collection_name}")
251267
return collection_name

tests/test_cursor_aggregate.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,3 +327,30 @@ def test_aggregate_multiple_stages(self, conn):
327327
total_users_idx = col_names.index("total_users")
328328
assert row[avg_age_idx] is not None and isinstance(row[avg_age_idx], (int, float))
329329
assert row[total_users_idx] is not None and isinstance(row[total_users_idx], (int, float))
330+
331+
def test_aggregate_collection_name_with_hyphen(self, conn):
332+
"""Test the aggregate function with a collection name containing a hyphen (user-orders)"""
333+
pipeline = json.dumps([{"$match": {"customer_type": "premium"}}])
334+
335+
# Test collection name with hyphen
336+
sql = f"""
337+
SELECT *
338+
FROM "user-orders".aggregate('{pipeline}', '{{}}')
339+
"""
340+
341+
cursor = conn.cursor()
342+
result = cursor.execute(sql)
343+
344+
assert result == cursor
345+
assert isinstance(cursor.result_set, ResultSet)
346+
347+
rows = cursor.result_set.fetchall()
348+
assert len(rows) > 0, "Should have results from user-orders collection"
349+
350+
# Verify all returned rows are premium customers
351+
col_names = [desc[0] for desc in cursor.result_set.description]
352+
assert "customer_type" in col_names, "customer_type should be in result columns"
353+
354+
customer_type_idx = col_names.index("customer_type")
355+
for row in rows:
356+
assert row[customer_type_idx] == "premium", "All rows should have customer_type='premium'"

0 commit comments

Comments
 (0)