|
4 | 4 | from django.core.exceptions import EmptyResultSet, FullResultSet |
5 | 5 | from django.db import DatabaseError, IntegrityError, NotSupportedError |
6 | 6 | from django.db.models.expressions import Case, Col, When |
| 7 | +from django.db.models.fields.related import ForeignKey |
7 | 8 | from django.db.models.functions import Mod |
8 | 9 | from django.db.models.lookups import Exact |
9 | 10 | from django.db.models.sql.constants import INNER |
@@ -181,14 +182,23 @@ def _get_reroot_replacements(expression): |
181 | 182 | lhs_fields = [] |
182 | 183 | rhs_fields = [] |
183 | 184 | # Add a join condition for each pair of joining fields. |
| 185 | + local_field = foreign_field = None |
184 | 186 | for lhs, rhs in self.join_fields: |
185 | | - lhs, rhs = connection.ops.prepare_join_on_clause( |
| 187 | + lhs_prepared, rhs_prepared = connection.ops.prepare_join_on_clause( |
186 | 188 | self.parent_alias, lhs, compiler.collection_name, rhs |
187 | 189 | ) |
188 | | - lhs_fields.append(lhs.as_mql(compiler, connection, as_expr=True)) |
189 | | - # In the lookup stage, the reference to this column doesn't include the |
190 | | - # collection name. |
191 | | - rhs_fields.append(rhs.as_mql(compiler, connection, as_expr=True)) |
| 190 | + if ( |
| 191 | + isinstance(lhs, ForeignKey) |
| 192 | + and isinstance(lhs_prepared, Col) |
| 193 | + and isinstance(rhs_prepared, Col) |
| 194 | + ): |
| 195 | + local_field = lhs_prepared.as_mql(compiler, connection) |
| 196 | + foreign_field = rhs_prepared.as_mql(compiler, connection) |
| 197 | + else: |
| 198 | + lhs_fields.append(lhs_prepared.as_mql(compiler, connection, as_expr=True)) |
| 199 | + # In the lookup stage, the reference to this column doesn't include the |
| 200 | + # collection name. |
| 201 | + rhs_fields.append(rhs_prepared.as_mql(compiler, connection, as_expr=True)) |
192 | 202 | # Handle any join conditions besides matching field pairs. |
193 | 203 | extra = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias) |
194 | 204 | extra_conditions = [] |
@@ -218,32 +228,47 @@ def _get_reroot_replacements(expression): |
218 | 228 | # self.table_name.field2 = parent_table.field2 |
219 | 229 | # AND |
220 | 230 | # ... |
221 | | - condition = { |
222 | | - "$expr": { |
223 | | - "$and": [ |
224 | | - {"$eq": [f"$${parent_template}{i}", field]} for i, field in enumerate(rhs_fields) |
225 | | - ] |
226 | | - } |
227 | | - } |
| 231 | + all_conditions = [] |
| 232 | + if rhs_fields: |
| 233 | + all_conditions.append( |
| 234 | + { |
| 235 | + "$expr": { |
| 236 | + "$and": [ |
| 237 | + {"$eq": [f"$${parent_template}{i}", field]} |
| 238 | + for i, field in enumerate(rhs_fields) |
| 239 | + ] |
| 240 | + } |
| 241 | + } |
| 242 | + ) |
228 | 243 | if extra_conditions: |
229 | | - condition = {"$and": [condition, *extra_conditions]} |
230 | | - lookup_pipeline = [ |
231 | | - { |
232 | | - "$lookup": { |
233 | | - # The right-hand table to join. |
234 | | - "from": self.table_name, |
235 | | - # The pipeline variables to be matched in the pipeline's |
236 | | - # expression. |
237 | | - "let": { |
238 | | - f"{parent_template}{i}": parent_field |
239 | | - for i, parent_field in enumerate(lhs_fields) |
240 | | - }, |
241 | | - "pipeline": [{"$match": condition}], |
242 | | - # Rename the output as table_alias. |
243 | | - "as": self.table_alias, |
| 244 | + all_conditions.extend(extra_conditions) |
| 245 | + # Build matching pipeline |
| 246 | + if len(all_conditions) == 0: |
| 247 | + pipeline = [] |
| 248 | + elif len(all_conditions) == 1: |
| 249 | + pipeline = [{"$match": all_conditions[0]}] |
| 250 | + else: |
| 251 | + pipeline = [{"$match": {"$and": all_conditions}}] |
| 252 | + |
| 253 | + lookup = { |
| 254 | + # The right-hand table to join. |
| 255 | + "from": self.table_name, |
| 256 | + "pipeline": pipeline, |
| 257 | + # Rename the output as table_alias. |
| 258 | + "as": self.table_alias, |
| 259 | + } |
| 260 | + if local_field and foreign_field: |
| 261 | + lookup.update( |
| 262 | + { |
| 263 | + "localField": local_field, |
| 264 | + "foreignField": foreign_field, |
244 | 265 | } |
245 | | - }, |
246 | | - ] |
| 266 | + ) |
| 267 | + if lhs_fields: |
| 268 | + lookup["let"] = { |
| 269 | + f"{parent_template}{i}": parent_field for i, parent_field in enumerate(lhs_fields) |
| 270 | + } |
| 271 | + lookup_pipeline = [{"$lookup": lookup}] |
247 | 272 | # To avoid missing data when using $unwind, an empty collection is added if |
248 | 273 | # the join isn't an inner join. For inner joins, rows with empty arrays are |
249 | 274 | # removed, as $unwind unrolls or unnests the array and removes the row if |
|
0 commit comments