11import datetime
22from decimal import Decimal
3+ from functools import partialmethod
34from uuid import UUID
45
56from bson import Decimal128
2930from django_mongodb_backend .query_utils import process_lhs
3031
3132
32- def base_expression (self , compiler , connection , as_path = False , ** extra ):
33- if as_path and hasattr (self , "as_mql_path" ) and getattr (self , "can_use_path" , False ):
33+ def base_expression (self , compiler , connection , as_expr = False , ** extra ):
34+ if not as_expr and hasattr (self , "as_mql_path" ) and getattr (self , "can_use_path" , False ):
3435 return self .as_mql_path (compiler , connection , ** extra )
3536
3637 expr = self .as_mql_expr (compiler , connection , ** extra )
37- return {"$expr" : expr } if as_path else expr
38+ return expr if as_expr else {"$expr" : expr }
3839
3940
4041def case (self , compiler , connection ):
4142 case_parts = []
4243 for case in self .cases :
4344 case_mql = {}
4445 try :
45- case_mql ["case" ] = case .as_mql (compiler , connection )
46+ case_mql ["case" ] = case .as_mql (compiler , connection , as_expr = True )
4647 except EmptyResultSet :
4748 continue
4849 except FullResultSet :
49- default_mql = case .result .as_mql (compiler , connection )
50+ default_mql = case .result .as_mql (compiler , connection , as_expr = True )
5051 break
51- case_mql ["then" ] = case .result .as_mql (compiler , connection )
52+ case_mql ["then" ] = case .result .as_mql (compiler , connection , as_expr = True )
5253 case_parts .append (case_mql )
5354 else :
54- default_mql = self .default .as_mql (compiler , connection )
55+ default_mql = self .default .as_mql (compiler , connection , as_expr = True )
5556 if not case_parts :
5657 return default_mql
5758 return {
@@ -62,7 +63,7 @@ def case(self, compiler, connection):
6263 }
6364
6465
65- def col (self , compiler , connection , as_path = False ): # noqa: ARG001
66+ def col (self , compiler , connection , as_expr = False ): # noqa: ARG001
6667 # If the column is part of a subquery and belongs to one of the parent
6768 # queries, it will be stored for reference using $let in a $lookup stage.
6869 # If the query is built with `alias_cols=False`, treat the column as
@@ -80,39 +81,39 @@ def col(self, compiler, connection, as_path=False): # noqa: ARG001
8081 # Add the column's collection's alias for columns in joined collections.
8182 has_alias = self .alias and self .alias != compiler .collection_name
8283 prefix = f"{ self .alias } ." if has_alias else ""
83- if not as_path :
84+ if as_expr :
8485 prefix = f"${ prefix } "
8586 return f"{ prefix } { self .target .column } "
8687
8788
88- def col_pairs (self , compiler , connection , as_path = False ):
89+ def col_pairs (self , compiler , connection , as_expr = False ):
8990 cols = self .get_cols ()
9091 if len (cols ) > 1 :
9192 raise NotSupportedError ("ColPairs is not supported." )
92- return cols [0 ].as_mql (compiler , connection , as_path = as_path )
93+ return cols [0 ].as_mql (compiler , connection , as_expr = as_expr )
9394
9495
9596def combined_expression (self , compiler , connection ):
9697 expressions = [
97- self .lhs .as_mql (compiler , connection ),
98- self .rhs .as_mql (compiler , connection ),
98+ self .lhs .as_mql (compiler , connection , as_expr = True ),
99+ self .rhs .as_mql (compiler , connection , as_expr = True ),
99100 ]
100101 return connection .ops .combine_expression (self .connector , expressions )
101102
102103
103104def expression_wrapper (self , compiler , connection ):
104- return self .expression .as_mql (compiler , connection )
105+ return self .expression .as_mql (compiler , connection , as_expr = True )
105106
106107
107108def negated_expression (self , compiler , connection ):
108109 return {"$not" : expression_wrapper (self , compiler , connection )}
109110
110111
111112def order_by (self , compiler , connection ):
112- return self .expression .as_mql (compiler , connection )
113+ return self .expression .as_mql (compiler , connection , as_expr = True )
113114
114115
115- def query (self , compiler , connection , get_wrapping_pipeline = None , as_path = False ):
116+ def query (self , compiler , connection , get_wrapping_pipeline = None , as_expr = False ):
116117 subquery_compiler = self .get_compiler (connection = connection )
117118 subquery_compiler .pre_sql_setup (with_col_aliases = False )
118119 field_name , expr = subquery_compiler .columns [0 ]
@@ -132,7 +133,7 @@ def query(self, compiler, connection, get_wrapping_pipeline=None, as_path=False)
132133 "as" : table_output ,
133134 "from" : from_table ,
134135 "let" : {
135- compiler .PARENT_FIELD_TEMPLATE .format (i ): col .as_mql (compiler , connection )
136+ compiler .PARENT_FIELD_TEMPLATE .format (i ): col .as_mql (compiler , connection , as_expr = True )
136137 for col , i in subquery_compiler .column_indices .items ()
137138 },
138139 }
@@ -154,16 +155,16 @@ def query(self, compiler, connection, get_wrapping_pipeline=None, as_path=False)
154155 # Erase project_fields since the required value is projected above.
155156 subquery .project_fields = None
156157 compiler .subqueries .append (subquery )
157- if as_path :
158- return f"{ table_output } .{ field_name } "
159- return f"$ { table_output } .{ field_name } "
158+ if as_expr :
159+ return f"$ { table_output } .{ field_name } "
160+ return f"{ table_output } .{ field_name } "
160161
161162
162163def raw_sql (self , compiler , connection ): # noqa: ARG001
163164 raise NotSupportedError ("RawSQL is not supported on MongoDB." )
164165
165166
166- def ref (self , compiler , connection , as_path = False ): # noqa: ARG001
167+ def ref (self , compiler , connection , as_expr = False ): # noqa: ARG001
167168 prefix = (
168169 f"{ self .source .alias } ."
169170 if isinstance (self .source , Col ) and self .source .alias != compiler .collection_name
@@ -173,7 +174,7 @@ def ref(self, compiler, connection, as_path=False): # noqa: ARG001
173174 refs , _ = compiler .columns [self .ordinal - 1 ]
174175 else :
175176 refs = self .refs
176- if not as_path :
177+ if as_expr :
177178 prefix = f"${ prefix } "
178179 return f"{ prefix } { refs } "
179180
@@ -187,27 +188,29 @@ def star(self, compiler, connection): # noqa: ARG001
187188 return {"$literal" : True }
188189
189190
190- def subquery (self , compiler , connection , get_wrapping_pipeline = None ):
191+ def subquery (self , compiler , connection , get_wrapping_pipeline = None , as_expr = False ):
191192 return self .query .as_mql (
192- compiler , connection , get_wrapping_pipeline = get_wrapping_pipeline , as_path = False
193+ compiler , connection , get_wrapping_pipeline = get_wrapping_pipeline , as_expr = as_expr
193194 )
194195
195196
196197def exists (self , compiler , connection , get_wrapping_pipeline = None ):
197198 try :
198- lhs_mql = subquery (self , compiler , connection , get_wrapping_pipeline = get_wrapping_pipeline )
199+ lhs_mql = subquery (
200+ self , compiler , connection , get_wrapping_pipeline = get_wrapping_pipeline , as_expr = True
201+ )
199202 except EmptyResultSet :
200- return Value (False ).as_mql (compiler , connection )
203+ return Value (False ).as_mql (compiler , connection , as_expr = True )
201204 return connection .mongo_expr_operators ["isnull" ](lhs_mql , False )
202205
203206
204207def when (self , compiler , connection ):
205- return self .condition .as_mql (compiler , connection )
208+ return self .condition .as_mql (compiler , connection , as_expr = True )
206209
207210
208- def value (self , compiler , connection , as_path = False ): # noqa: ARG001
211+ def value (self , compiler , connection , as_expr = False ): # noqa: ARG001
209212 value = self .value
210- if isinstance (value , (list , int )) and not as_path :
213+ if isinstance (value , (list , int )) and as_expr :
211214 # Wrap lists & numbers in $literal to prevent ambiguity when Value
212215 # appears in $project.
213216 return {"$literal" : value }
@@ -248,6 +251,7 @@ def register_expressions():
248251 Ref .is_simple_column = ref_is_simple_column
249252 ResolvedOuterRef .as_mql = ResolvedOuterRef .as_sql
250253 Star .as_mql_expr = star
251- Subquery .as_mql_expr = subquery
254+ Subquery .as_mql_expr = partialmethod (subquery , as_expr = True )
255+ Subquery .as_mql_path = partialmethod (subquery , as_expr = False )
252256 When .as_mql_expr = when
253257 Value .as_mql = value
0 commit comments