Skip to content

Commit 503233d

Browse files
authored
Add aggregate function support (#13)
1 parent da087fd commit 503233d

File tree

10 files changed

+699
-9
lines changed

10 files changed

+699
-9
lines changed

MANIFEST.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ global-exclude *.pyc
1414
global-exclude *.pyo
1515
global-exclude __pycache__
1616
global-exclude .git*
17+
global-exclude .github
1718
global-exclude .pytest_cache
1819
global-exclude .coverage
1920
global-exclude htmlcov

README.md

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ PyMongoSQL implements the DB API 2.0 interfaces to provide SQL-like access to Mo
2626
- **DB API 2.0 Compliant**: Full compatibility with Python Database API 2.0 specification
2727
- **PartiQL-based SQL Syntax**: Built on [PartiQL](https://partiql.org/tutorial.html) (SQL for semi-structured data), enabling seamless SQL querying of nested and hierarchical MongoDB documents
2828
- **Nested Structure Support**: Query and filter deeply nested fields and arrays within MongoDB documents using standard SQL syntax
29+
- **MongoDB Aggregate Pipeline Support**: Execute native MongoDB aggregation pipelines from SQL-like syntax using the `aggregate()` function
2930
- **SQLAlchemy Integration**: Complete ORM and Core support with dedicated MongoDB dialect
3031
- **SQL Query Support**: SELECT statements with WHERE conditions, field selection, and aliases
3132
- **DML Support**: Full support for INSERT, UPDATE, and DELETE operations using PartiQL syntax
@@ -80,6 +81,7 @@ pip install -e .
8081
- [WHERE Clauses](#where-clauses)
8182
- [Nested Field Support](#nested-field-support)
8283
- [Sorting and Limiting](#sorting-and-limiting)
84+
- [MongoDB Aggregate Function](#mongodb-aggregate-function)
8385
- [INSERT Statements](#insert-statements)
8486
- [UPDATE Statements](#update-statements)
8587
- [DELETE Statements](#delete-statements)
@@ -235,6 +237,61 @@ Parameters are substituted into the MongoDB filter during execution, providing p
235237
- **LIMIT**: `LIMIT 10`
236238
- **Combined**: `ORDER BY created_at DESC LIMIT 5`
237239

240+
### MongoDB Aggregate Function
241+
242+
PyMongoSQL supports executing native MongoDB aggregation pipelines using SQL-like syntax with the `aggregate()` function. This allows you to leverage MongoDB's powerful aggregation framework while maintaining SQL-style query patterns.
243+
244+
**Syntax**
245+
246+
The `aggregate()` function accepts two parameters:
247+
- **pipeline**: JSON string representing the MongoDB aggregation pipeline
248+
- **options**: JSON string for aggregation options (optional, use '{}' for defaults)
249+
250+
**Qualified Aggregate (Collection-Specific)**
251+
252+
```python
253+
cursor.execute(
254+
"SELECT * FROM users.aggregate('[{\"$match\": {\"age\": {\"$gt\": 25}}}, {\"$group\": {\"_id\": \"$city\", \"count\": {\"$sum\": 1}}}]', '{}')"
255+
)
256+
results = cursor.fetchall()
257+
```
258+
259+
**Unqualified Aggregate (Database-Level)**
260+
261+
```python
262+
cursor.execute(
263+
"SELECT * FROM aggregate('[{\"$match\": {\"status\": \"active\"}}]', '{\"allowDiskUse\": true}')"
264+
)
265+
results = cursor.fetchall()
266+
```
267+
268+
**Post-Aggregation Filtering and Sorting**
269+
270+
You can apply WHERE, ORDER BY, and LIMIT clauses after aggregation:
271+
272+
```python
273+
# Filter aggregation results
274+
cursor.execute(
275+
"SELECT * FROM users.aggregate('[{\"$group\": {\"_id\": \"$city\", \"total\": {\"$sum\": 1}}}]', '{}') WHERE total > 100"
276+
)
277+
278+
# Sort and limit aggregation results
279+
cursor.execute(
280+
"SELECT * FROM products.aggregate('[{\"$match\": {\"category\": \"Electronics\"}}]', '{}') ORDER BY price DESC LIMIT 10"
281+
)
282+
```
283+
284+
**Projection Support**
285+
286+
```python
287+
# Select specific fields from aggregation results
288+
cursor.execute(
289+
"SELECT _id, total FROM users.aggregate('[{\"$group\": {\"_id\": \"$city\", \"total\": {\"$sum\": 1}}}]', '{}')"
290+
)
291+
```
292+
293+
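**Note**: The pipeline and options must be valid JSON strings enclosed in single quotes. Post-aggregation filtering (WHERE), sorting (ORDER BY), limiting (LIMIT), and column projection are applied in Python, in that order, after the aggregation executes on MongoDB. The complete aggregation result is therefore loaded into memory first; for large result sets, prefer expressing the filter, sort, and limit as `$match`, `$sort`, and `$limit` stages inside the pipeline itself.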
294+
238295
### INSERT Statements
239296

240297
PyMongoSQL supports inserting documents into MongoDB collections using both PartiQL-style object literals and standard SQL INSERT VALUES syntax.

pymongosql/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
if TYPE_CHECKING:
77
from .connection import Connection
88

9-
__version__: str = "0.3.3"
9+
__version__: str = "0.3.4"
1010

1111
# Globals https://www.python.org/dev/peps/pep-0249/#globals
1212
apilevel: str = "2.0"

pymongosql/executor.py

Lines changed: 163 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def _replace_placeholders(self, obj: Any, parameters: Sequence[Any]) -> Any:
9898
"""Recursively replace ? placeholders with parameter values in filter/projection dicts"""
9999
return SQLHelper.replace_placeholders_generic(obj, parameters, "qmark")
100100

101-
def _execute_execution_plan(
101+
def _execute_find_plan(
102102
self,
103103
execution_plan: QueryExecutionPlan,
104104
connection: Any = None,
@@ -172,6 +172,163 @@ def _execute_execution_plan(
172172
_logger.error(f"Unexpected error during command execution: {e}")
173173
raise OperationalError(f"Command execution error: {e}")
174174

175+
def _execute_aggregate_plan(
    self,
    execution_plan: QueryExecutionPlan,
    connection: Any = None,
    parameters: Optional[Sequence[Any]] = None,
) -> Optional[Dict[str, Any]]:
    """Execute a QueryExecutionPlan that wraps an ``aggregate()`` call.

    The pipeline runs on MongoDB; any WHERE / ORDER BY / OFFSET / LIMIT /
    projection carried by the plan is then applied in Python, in that order,
    to the fully materialized result list.

    Args:
        execution_plan: QueryExecutionPlan whose ``aggregate_pipeline`` and
            ``aggregate_options`` hold JSON strings (a JSON array of stages
            and a JSON object of keyword options respectively).
        connection: Connection object; its ``database`` attribute is used to
            look up the collection.
        parameters: Accepted for signature parity with the find executor;
            this method does not currently substitute placeholders.

    Returns:
        A command-style result: ``{"cursor": {"firstBatch": [...]}, "ok": 1}``.

    Raises:
        OperationalError: No connection was provided, or an unexpected error
            occurred while executing or post-processing the aggregation.
        ProgrammingError: No collection is set on the plan, the pipeline or
            options are not valid JSON, or they have the wrong JSON type.
        DatabaseError: MongoDB reported an error (``PyMongoError``).
    """
    import json

    try:
        # Get database from connection
        if not connection:
            raise OperationalError("No connection provided")

        db = connection.database

        if not execution_plan.collection:
            raise ProgrammingError("No collection specified in aggregate query")

        # Parse pipeline and options from JSON strings
        try:
            pipeline = json.loads(execution_plan.aggregate_pipeline or "[]")
            options = json.loads(execution_plan.aggregate_options or "{}")
        except json.JSONDecodeError as e:
            raise ProgrammingError(f"Invalid JSON in aggregate pipeline or options: {e}") from e

        # Valid JSON of the wrong shape is a query error, not a runtime failure
        if not isinstance(pipeline, list):
            raise ProgrammingError("Aggregate pipeline must be a JSON array of stages")
        if not isinstance(options, dict):
            raise ProgrammingError("Aggregate options must be a JSON object")

        _logger.debug(f"Executing aggregate on collection {execution_plan.collection}")
        _logger.debug(f"Pipeline: {pipeline}")
        _logger.debug(f"Options: {options}")

        # Get collection and call aggregate(); options become keyword arguments
        collection = db[execution_plan.collection]
        cursor = collection.aggregate(pipeline, **options)

        # Materialize the cursor: everything below operates on an in-memory list
        results = list(cursor)

        # Apply the WHERE clause in Python, since the aggregation already ran
        if execution_plan.filter_stage:
            _logger.debug(f"Applying additional filter: {execution_plan.filter_stage}")
            results = self._filter_results(results, execution_plan.filter_stage)

        # Apply sorting: stable sorts from the least to the most significant key
        if execution_plan.sort_stage:
            for sort_dict in reversed(execution_plan.sort_stage):
                for field_name, direction in reversed(list(sort_dict.items())):

                    def sort_key(document, name=field_name):
                        # Missing/None values sort before any real value instead
                        # of raising TypeError when compared with one
                        value = document.get(name)
                        return (value is not None, value)

                    results = sorted(results, key=sort_key, reverse=(direction == -1))

        # Apply skip and limit (a skip or limit of 0/None leaves results untouched)
        if execution_plan.skip_stage:
            results = results[execution_plan.skip_stage :]

        if execution_plan.limit_stage:
            results = results[: execution_plan.limit_stage]

        # Apply projection last so filter and sort can use unprojected fields
        if execution_plan.projection_stage:
            results = self._apply_projection(results, execution_plan.projection_stage)

        # Return in command result format
        return {
            "cursor": {"firstBatch": results},
            "ok": 1,
        }

    except (ProgrammingError, OperationalError):
        raise
    except PyMongoError as e:
        _logger.error(f"MongoDB aggregate execution failed: {e}")
        raise DatabaseError(f"Aggregate execution failed: {e}") from e
    except Exception as e:
        _logger.error(f"Unexpected error during aggregate execution: {e}")
        raise OperationalError(f"Aggregate execution error: {e}") from e
262+
263+
@staticmethod
def _filter_results(results: list, filter_conditions: dict) -> list:
    """Return the documents in ``results`` that satisfy ``filter_conditions``.

    This is a simplified in-Python evaluation of a MongoDB-style filter; the
    per-document decision is delegated to
    ``StandardQueryExecution._matches_filter``. Input order is preserved.
    """
    is_match = StandardQueryExecution._matches_filter
    return [document for document in results if is_match(document, filter_conditions)]
273+
274+
@staticmethod
275+
def _matches_filter(doc: dict, filter_conditions: dict) -> bool:
276+
"""Check if a document matches the filter conditions"""
277+
for field, condition in filter_conditions.items():
278+
if field == "$and":
279+
return all(StandardQueryExecution._matches_filter(doc, cond) for cond in condition)
280+
elif field == "$or":
281+
return any(StandardQueryExecution._matches_filter(doc, cond) for cond in condition)
282+
elif isinstance(condition, dict):
283+
# Handle operators like $eq, $gt, etc.
284+
for op, value in condition.items():
285+
if op == "$eq":
286+
if doc.get(field) != value:
287+
return False
288+
elif op == "$ne":
289+
if doc.get(field) == value:
290+
return False
291+
elif op == "$gt":
292+
if not (doc.get(field) > value):
293+
return False
294+
elif op == "$gte":
295+
if not (doc.get(field) >= value):
296+
return False
297+
elif op == "$lt":
298+
if not (doc.get(field) < value):
299+
return False
300+
elif op == "$lte":
301+
if not (doc.get(field) <= value):
302+
return False
303+
else:
304+
if doc.get(field) != condition:
305+
return False
306+
return True
307+
308+
@staticmethod
309+
def _apply_projection(results: list, projection_stage: dict) -> list:
310+
"""Apply projection to results"""
311+
projected = []
312+
include_fields = {k for k, v in projection_stage.items() if v == 1}
313+
exclude_fields = {k for k, v in projection_stage.items() if v == 0}
314+
315+
for doc in results:
316+
if include_fields:
317+
# Include mode: only include specified fields
318+
projected_doc = (
319+
{"_id": doc.get("_id")} if "_id" in include_fields or "_id" not in projection_stage else {}
320+
)
321+
for field in include_fields:
322+
if field != "_id" and field in doc:
323+
projected_doc[field] = doc[field]
324+
projected.append(projected_doc)
325+
else:
326+
# Exclude mode: exclude specified fields
327+
projected_doc = {k: v for k, v in doc.items() if k not in exclude_fields}
328+
projected.append(projected_doc)
329+
330+
return projected
331+
175332
def execute(
176333
self,
177334
context: ExecutionContext,
@@ -197,7 +354,11 @@ def execute(
197354
# Parse the query
198355
self._execution_plan = self._parse_sql(processed_query)
199356

200-
return self._execute_execution_plan(self._execution_plan, connection, processed_params)
357+
# Route to appropriate execution plan handler
358+
if hasattr(self._execution_plan, "is_aggregate_query") and self._execution_plan.is_aggregate_query:
359+
return self._execute_aggregate_plan(self._execution_plan, connection, processed_params)
360+
else:
361+
return self._execute_find_plan(self._execution_plan, connection, processed_params)
201362

202363

203364
class InsertExecution(ExecutionStrategy):

pymongosql/sql/builder.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,15 @@ def _build_query_plan(parse_result: "QueryParseResult") -> "QueryExecutionPlan":
118118
parse_result.column_aliases
119119
).sort(parse_result.sort_fields).limit(parse_result.limit_value).skip(parse_result.offset_value)
120120

121-
return builder.build()
121+
# Set aggregate flags BEFORE building (needed for validation)
122+
if hasattr(parse_result, "is_aggregate_query") and parse_result.is_aggregate_query:
123+
builder._execution_plan.is_aggregate_query = True
124+
builder._execution_plan.aggregate_pipeline = parse_result.aggregate_pipeline
125+
builder._execution_plan.aggregate_options = parse_result.aggregate_options
126+
127+
# Now build and validate
128+
plan = builder.build()
129+
return plan
122130

123131
@staticmethod
124132
def _build_insert_plan(parse_result: "InsertParseResult") -> "InsertExecutionPlan":

pymongosql/sql/query_builder.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,14 @@ class QueryExecutionPlan(ExecutionPlan):
1818
sort_stage: List[Dict[str, int]] = field(default_factory=list)
1919
limit_stage: Optional[int] = None
2020
skip_stage: Optional[int] = None
21+
# Aggregate pipeline support
22+
aggregate_pipeline: Optional[str] = None # JSON string representation of pipeline
23+
aggregate_options: Optional[str] = None # JSON string representation of options
24+
is_aggregate_query: bool = False # Flag indicating this is an aggregate() call
2125

2226
def to_dict(self) -> Dict[str, Any]:
2327
"""Convert query plan to dictionary representation"""
24-
return {
28+
result = {
2529
"collection": self.collection,
2630
"filter": self.filter_stage,
2731
"projection": self.projection_stage,
@@ -30,9 +34,22 @@ def to_dict(self) -> Dict[str, Any]:
3034
"skip": self.skip_stage,
3135
}
3236

37+
# Add aggregate-specific fields if present
38+
if self.is_aggregate_query:
39+
result["is_aggregate_query"] = True
40+
result["aggregate_pipeline"] = self.aggregate_pipeline
41+
result["aggregate_options"] = self.aggregate_options
42+
43+
return result
44+
3345
def validate(self) -> bool:
3446
"""Validate the query plan"""
35-
errors = self.validate_base()
47+
# For aggregate queries, collection is optional (unqualified aggregate syntax)
48+
# For regular queries, collection is required
49+
if self.is_aggregate_query:
50+
errors = []
51+
else:
52+
errors = self.validate_base()
3653

3754
if self.limit_stage is not None and (not isinstance(self.limit_stage, int) or self.limit_stage < 0):
3855
errors.append("Limit must be a non-negative integer")
@@ -56,6 +73,9 @@ def copy(self) -> "QueryExecutionPlan":
5673
sort_stage=self.sort_stage.copy(),
5774
limit_stage=self.limit_stage,
5875
skip_stage=self.skip_stage,
76+
aggregate_pipeline=self.aggregate_pipeline,
77+
aggregate_options=self.aggregate_options,
78+
is_aggregate_query=self.is_aggregate_query,
5979
)
6080

6181

@@ -217,7 +237,9 @@ def validate(self) -> bool:
217237
"""Validate the current query plan"""
218238
self._validation_errors.clear()
219239

220-
if not self._execution_plan.collection:
240+
# For aggregate queries, collection is optional (unqualified aggregate syntax)
241+
# For regular queries, collection is required
242+
if not self._execution_plan.is_aggregate_query and not self._execution_plan.collection:
221243
self._add_error("Collection name is required")
222244

223245
# Add more validation rules as needed

0 commit comments

Comments
 (0)