|
12 | 12 | from django.db.models.sql.constants import INNER, GET_ITERATOR_CHUNK_SIZE |
13 | 13 | from django.db.models.sql.datastructures import Join |
14 | 14 | from django.db.models.sql.where import AND, OR, XOR, ExtraWhere, NothingNode, WhereNode |
| 15 | +from django.db.models.sql import Query |
15 | 16 | from django.utils.functional import cached_property |
16 | 17 | from pymongo.errors import BulkWriteError, DuplicateKeyError, PyMongoError |
17 | 18 |
|
@@ -314,280 +315,12 @@ def register_nodes(): |
314 | 315 |
|
315 | 316 |
|
316 | 317 | class MongoQuerySet(QuerySet): |
317 | | - def raw_mql(self, raw_query, params=(), translations=None, using=None): |
318 | | - if using is None: |
319 | | - using = self.db |
320 | | - qs = RawQuerySet( |
321 | | - raw_query, |
322 | | - model=self.model, |
323 | | - params=params, |
324 | | - translations=translations, |
325 | | - using=using, |
326 | | - ) |
327 | | - return qs |
| 318 | + def raw_mql(self, raw_query): |
| 319 | + return QuerySet(self.model, RawQuery(self.model, raw_query)) |
328 | 320 |
|
329 | 321 |
|
330 | | -class RawQuerySet: |
331 | | - """ |
332 | | - Provide an iterator which converts the results of raw SQL queries into |
333 | | - annotated model instances. |
334 | | - """ |
| 322 | +class RawQuery(Query): |
335 | 323 |
|
336 | | - def __init__( |
337 | | - self, |
338 | | - raw_query, |
339 | | - model=None, |
340 | | - query=None, |
341 | | - params=(), |
342 | | - translations=None, |
343 | | - using=None, |
344 | | - hints=None, |
345 | | - ): |
| 324 | + def __init__(self, model, raw_query): |
| 325 | + super(RawQuery, self).__init__(model) |
346 | 326 | self.raw_query = raw_query |
347 | | - self.model = model |
348 | | - self._db = using |
349 | | - self._hints = hints or {} |
350 | | - self.query = query or RawQuery(sql=raw_query, using=self.db, params=params) |
351 | | - self.params = params |
352 | | - self.translations = translations or {} |
353 | | - self._result_cache = None |
354 | | - self._prefetch_related_lookups = () |
355 | | - self._prefetch_done = False |
356 | | - |
357 | | - def resolve_model_init_order(self): |
358 | | - """Resolve the init field names and value positions.""" |
359 | | - converter = connections[self.db].introspection.identifier_converter |
360 | | - model_init_fields = [ |
361 | | - f for f in self.model._meta.fields if converter(f.column) in self.columns |
362 | | - ] |
363 | | - annotation_fields = [ |
364 | | - (column, pos) |
365 | | - for pos, column in enumerate(self.columns) |
366 | | - if column not in self.model_fields |
367 | | - ] |
368 | | - model_init_order = [self.columns.index(converter(f.column)) for f in model_init_fields] |
369 | | - model_init_names = [f.attname for f in model_init_fields] |
370 | | - return model_init_names, model_init_order, annotation_fields |
371 | | - |
372 | | - def prefetch_related(self, *lookups): |
373 | | - """Same as QuerySet.prefetch_related()""" |
374 | | - clone = self._clone() |
375 | | - if lookups == (None,): |
376 | | - clone._prefetch_related_lookups = () |
377 | | - else: |
378 | | - clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups |
379 | | - return clone |
380 | | - |
381 | | - def _prefetch_related_objects(self): |
382 | | - prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups) |
383 | | - self._prefetch_done = True |
384 | | - |
385 | | - def _clone(self): |
386 | | - """Same as QuerySet._clone()""" |
387 | | - c = self.__class__( |
388 | | - self.raw_query, |
389 | | - model=self.model, |
390 | | - query=self.query, |
391 | | - params=self.params, |
392 | | - translations=self.translations, |
393 | | - using=self._db, |
394 | | - hints=self._hints, |
395 | | - ) |
396 | | - c._prefetch_related_lookups = self._prefetch_related_lookups[:] |
397 | | - return c |
398 | | - |
399 | | - def _fetch_all(self): |
400 | | - if self._result_cache is None: |
401 | | - self._result_cache = list(self.iterator()) |
402 | | - if self._prefetch_related_lookups and not self._prefetch_done: |
403 | | - self._prefetch_related_objects() |
404 | | - |
405 | | - def __len__(self): |
406 | | - self._fetch_all() |
407 | | - return len(self._result_cache) |
408 | | - |
409 | | - def __bool__(self): |
410 | | - self._fetch_all() |
411 | | - return bool(self._result_cache) |
412 | | - |
413 | | - def __iter__(self): |
414 | | - self._fetch_all() |
415 | | - return iter(self._result_cache) |
416 | | - |
417 | | - def __aiter__(self): |
418 | | - # Remember, __aiter__ itself is synchronous, it's the thing it returns |
419 | | - # that is async! |
420 | | - async def generator(): |
421 | | - await sync_to_async(self._fetch_all)() |
422 | | - for item in self._result_cache: |
423 | | - yield item |
424 | | - |
425 | | - return generator() |
426 | | - |
427 | | - def iterator(self): |
428 | | - yield from RawModelIterable(self) |
429 | | - |
430 | | - def __repr__(self): |
431 | | - return "<%s: %s>" % (self.__class__.__name__, self.query) |
432 | | - |
433 | | - def __getitem__(self, k): |
434 | | - return list(self)[k] |
435 | | - |
436 | | - @property |
437 | | - def db(self): |
438 | | - """Return the database used if this query is executed now.""" |
439 | | - return self._db or router.db_for_read(self.model, **self._hints) |
440 | | - |
441 | | - def using(self, alias): |
442 | | - """Select the database this RawQuerySet should execute against.""" |
443 | | - return RawQuerySet( |
444 | | - self.raw_query, |
445 | | - model=self.model, |
446 | | - query=self.query.chain(using=alias), |
447 | | - params=self.params, |
448 | | - translations=self.translations, |
449 | | - using=alias, |
450 | | - ) |
451 | | - |
452 | | - @cached_property |
453 | | - def columns(self): |
454 | | - """ |
455 | | - A list of model field names in the order they'll appear in the |
456 | | - query results. |
457 | | - """ |
458 | | - columns = self.query.get_columns() |
459 | | - # Adjust any column names which don't match field names |
460 | | - for query_name, model_name in self.translations.items(): |
461 | | - # Ignore translations for nonexistent column names |
462 | | - try: |
463 | | - index = columns.index(query_name) |
464 | | - except ValueError: |
465 | | - pass |
466 | | - else: |
467 | | - columns[index] = model_name |
468 | | - return columns |
469 | | - |
470 | | - @cached_property |
471 | | - def model_fields(self): |
472 | | - """A dict mapping column names to model field names.""" |
473 | | - converter = connections[self.db].introspection.identifier_converter |
474 | | - model_fields = {} |
475 | | - for field in self.model._meta.fields: |
476 | | - name, column = field.get_attname_column() |
477 | | - model_fields[converter(column)] = field |
478 | | - return model_fields |
479 | | - |
480 | | - |
481 | | -class RawQuery: |
482 | | - """A single raw SQL query.""" |
483 | | - |
484 | | - def __init__(self, sql, using, params=()): |
485 | | - self.params = params |
486 | | - self.sql = sql |
487 | | - self.using = using |
488 | | - self.cursor = None |
489 | | - |
490 | | - # Mirror some properties of a normal query so that |
491 | | - # the compiler can be used to process results. |
492 | | - self.low_mark, self.high_mark = 0, None # Used for offset/limit |
493 | | - self.extra_select = {} |
494 | | - self.annotation_select = {} |
495 | | - |
496 | | - def chain(self, using): |
497 | | - return self.clone(using) |
498 | | - |
499 | | - def clone(self, using): |
500 | | - return RawQuery(self.sql, using, params=self.params) |
501 | | - |
502 | | - def get_columns(self): |
503 | | - if self.cursor is None: |
504 | | - self._execute_query() |
505 | | - converter = connections[self.using].introspection.identifier_converter |
506 | | - return [converter(column_meta[0]) for column_meta in self.cursor.description] |
507 | | - |
508 | | - def __iter__(self): |
509 | | - # Always execute a new query for a new iterator. |
510 | | - # This could be optimized with a cache at the expense of RAM. |
511 | | - self._execute_query() |
512 | | - if not connections[self.using].features.can_use_chunked_reads: |
513 | | - # If the database can't use chunked reads we need to make sure we |
514 | | - # evaluate the entire query up front. |
515 | | - result = list(self.cursor) |
516 | | - else: |
517 | | - result = self.cursor |
518 | | - return iter(result) |
519 | | - |
520 | | - def __repr__(self): |
521 | | - return "<%s: %s>" % (self.__class__.__name__, self) |
522 | | - |
523 | | - @property |
524 | | - def params_type(self): |
525 | | - if self.params is None: |
526 | | - return None |
527 | | - return dict if isinstance(self.params, Mapping) else tuple |
528 | | - |
529 | | - def __str__(self): |
530 | | - if self.params_type is None: |
531 | | - return self.sql |
532 | | - return self.sql % self.params_type(self.params) |
533 | | - |
534 | | - def _execute_query(self): |
535 | | - connection = connections[self.using] |
536 | | - |
537 | | - # Adapt parameters to the database, as much as possible considering |
538 | | - # that the target type isn't known. See #17755. |
539 | | - params_type = self.params_type |
540 | | - adapter = connection.ops.adapt_unknown_value |
541 | | - if params_type is tuple: |
542 | | - params = tuple(adapter(val) for val in self.params) |
543 | | - elif params_type is dict: |
544 | | - params = {key: adapter(val) for key, val in self.params.items()} |
545 | | - elif params_type is None: |
546 | | - params = None |
547 | | - else: |
548 | | - raise RuntimeError("Unexpected params type: %s" % params_type) |
549 | | - |
550 | | - self.cursor = connection.cursor() |
551 | | - self.cursor.execute(self.sql, params) |
552 | | - |
553 | | - |
554 | | -class RawModelIterable(BaseIterable): |
555 | | - """ |
556 | | - Iterable that yields a model instance for each row from a raw queryset. |
557 | | - """ |
558 | | - |
559 | | - def __iter__(self): |
560 | | - # Cache some things for performance reasons outside the loop. |
561 | | - db = self.queryset.db |
562 | | - query = self.queryset.query |
563 | | - connection = connections[db] |
564 | | - compiler = connection.ops.compiler("SQLCompiler")(query, connection, db) |
565 | | - query_iterator = iter(query) |
566 | | - |
567 | | - try: |
568 | | - ( |
569 | | - model_init_names, |
570 | | - model_init_pos, |
571 | | - annotation_fields, |
572 | | - ) = self.queryset.resolve_model_init_order() |
573 | | - model_cls = self.queryset.model |
574 | | - if model_cls._meta.pk.attname not in model_init_names: |
575 | | - raise exceptions.FieldDoesNotExist("Raw query must include the primary key") |
576 | | - fields = [self.queryset.model_fields.get(c) for c in self.queryset.columns] |
577 | | - converters = compiler.get_converters( |
578 | | - [f.get_col(f.model._meta.db_table) if f else None for f in fields] |
579 | | - ) |
580 | | - if converters: |
581 | | - query_iterator = compiler.apply_converters(query_iterator, converters) |
582 | | - for values in query_iterator: |
583 | | - # Associate fields to values |
584 | | - model_init_values = [values[pos] for pos in model_init_pos] |
585 | | - instance = model_cls.from_db(db, model_init_names, model_init_values) |
586 | | - if annotation_fields: |
587 | | - for column, pos in annotation_fields: |
588 | | - setattr(instance, column, values[pos]) |
589 | | - yield instance |
590 | | - finally: |
591 | | - # Done iterating the Query. If it has its own cursor, close it. |
592 | | - if hasattr(query, "cursor") and query.cursor: |
593 | | - query.cursor.close() |
0 commit comments