Skip to content

added aggregation filter end to mongodb extensions #1538

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions fastapi_pagination/ext/beanie.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
__all__ = ["apaginate", "paginate"]

from copy import copy
from typing import Any, Optional, TypeVar, Union
from typing import Any, Literal, Optional, TypeVar, Union

from beanie import Document, PydanticObjectId
from beanie.odm.enums import SortDirection
Expand All @@ -16,6 +16,7 @@

from fastapi_pagination.api import apply_items_transformer, create_page
from fastapi_pagination.bases import AbstractParams, is_cursor, is_limit_offset
from fastapi_pagination.ext.utils import get_mongo_pipeline_filter_end
from fastapi_pagination.types import AdditionalData, AsyncItemsTransformer
from fastapi_pagination.utils import verify_params

Expand All @@ -42,6 +43,7 @@ async def apaginate( # noqa: C901, PLR0912, PLR0915
ignore_cache: bool = False,
fetch_links: bool = False,
lazy_parse: bool = False,
aggregation_filter_end: Optional[Union[int, Literal["auto"]]] = None,
**pymongo_kwargs: Any,
) -> Any:
params, raw_params = verify_params(params, "limit-offset", "cursor")
Expand Down Expand Up @@ -74,12 +76,21 @@ async def apaginate( # noqa: C901, PLR0912, PLR0915
},
},
)

aggregation_query.aggregation_pipeline.extend(
[
{"$facet": {"metadata": [{"$count": "total"}], "data": paginate_data}},
],
)
if aggregation_filter_end is not None:
if aggregation_filter_end == "auto":
aggregation_filter_end = get_mongo_pipeline_filter_end(aggregation_query.aggregation_pipeline)
filter_part = aggregation_query.aggregation_pipeline[:aggregation_filter_end]
transform_part = aggregation_query.aggregation_pipeline[aggregation_filter_end:]
aggregation_query.aggregation_pipeline = [
*filter_part,
{"$facet": {"metadata": [{"$count": "total"}], "data": [*paginate_data, *transform_part]}},
]
else:
aggregation_query.aggregation_pipeline.extend(
[
{"$facet": {"metadata": [{"$count": "total"}], "data": paginate_data}},
],
)
data = (await aggregation_query.to_list())[0]
items = data["data"]
try:
Expand Down Expand Up @@ -178,6 +189,7 @@ async def paginate(
ignore_cache: bool = False,
fetch_links: bool = False,
lazy_parse: bool = False,
aggregation_filter_end: Optional[int] = None,
**pymongo_kwargs: Any,
) -> Any:
return await apaginate(
Expand All @@ -191,5 +203,6 @@ async def paginate(
ignore_cache=ignore_cache,
fetch_links=fetch_links,
lazy_parse=lazy_parse,
aggregation_filter_end=aggregation_filter_end,
**pymongo_kwargs,
)
25 changes: 19 additions & 6 deletions fastapi_pagination/ext/bunnet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = ["paginate"]

from typing import Any, Optional, TypeVar, Union
from typing import Any, Literal, Optional, TypeVar, Union

from bunnet import Document
from bunnet.odm.enums import SortDirection
Expand All @@ -10,6 +10,7 @@

from fastapi_pagination.api import apply_items_transformer, create_page
from fastapi_pagination.bases import AbstractParams
from fastapi_pagination.ext.utils import get_mongo_pipeline_filter_end
from fastapi_pagination.types import AdditionalData, SyncItemsTransformer
from fastapi_pagination.utils import verify_params

Expand All @@ -28,6 +29,7 @@ def paginate(
ignore_cache: bool = False,
fetch_links: bool = False,
lazy_parse: bool = False,
aggregation_filter_end: Optional[Union[int, Literal["auto"]]] = None,
**pymongo_kwargs: Any,
) -> Any:
params, raw_params = verify_params(params, "limit-offset")
Expand All @@ -40,11 +42,22 @@ def paginate(
if raw_params.offset is not None:
paginate_data.append({"$skip": raw_params.offset})

aggregation_query.aggregation_pipeline.extend(
[
{"$facet": {"metadata": [{"$count": "total"}], "data": paginate_data}},
],
)
if aggregation_filter_end is not None:
if aggregation_filter_end == "auto":
aggregation_filter_end = get_mongo_pipeline_filter_end(aggregation_query.aggregation_pipeline)
filter_part = aggregation_query.aggregation_pipeline[:aggregation_filter_end]
transform_part = aggregation_query.aggregation_pipeline[aggregation_filter_end:]
aggregation_query.aggregation_pipeline = [
*filter_part,
{"$facet": {"metadata": [{"$count": "total"}], "data": [*paginate_data, *transform_part]}},
]
else:
aggregation_query.aggregation_pipeline.extend(
[
{"$facet": {"metadata": [{"$count": "total"}], "data": paginate_data}},
],
)

data = aggregation_query.to_list()[0]
items = data["data"]
try:
Expand Down
11 changes: 10 additions & 1 deletion fastapi_pagination/ext/motor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
"paginate_aggregate",
]

from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Literal, Optional, Union

from motor.core import AgnosticCollection
from typing_extensions import TypeAlias, deprecated

from fastapi_pagination.api import apply_items_transformer, create_page
from fastapi_pagination.bases import AbstractParams
from fastapi_pagination.ext.utils import get_mongo_pipeline_filter_end
from fastapi_pagination.types import AdditionalData, AsyncItemsTransformer
from fastapi_pagination.utils import verify_params

Expand Down Expand Up @@ -58,6 +59,7 @@ async def apaginate_aggregate(
*,
transformer: Optional[AsyncItemsTransformer] = None,
additional_data: Optional[AdditionalData] = None,
aggregation_filter_end: Optional[Union[int, Literal["auto"]]] = None,
) -> Any:
params, raw_params = verify_params(params, "limit-offset")
aggregate_pipeline = aggregate_pipeline or []
Expand All @@ -68,6 +70,13 @@ async def apaginate_aggregate(
if raw_params.offset is not None:
paginate_data.append({"$skip": raw_params.offset})

if aggregation_filter_end is not None:
if aggregation_filter_end == "auto":
aggregation_filter_end = get_mongo_pipeline_filter_end(aggregate_pipeline)
transform_part = aggregate_pipeline[:aggregation_filter_end]
aggregate_pipeline = aggregate_pipeline[aggregation_filter_end:]
paginate_data.extend(transform_part)

cursor = collection.aggregate(
[
*aggregate_pipeline,
Expand Down
27 changes: 27 additions & 0 deletions fastapi_pagination/ext/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
__all__ = [
"generic_query_apply_params",
"get_mongo_pipeline_filter_end",
"len_or_none",
"unwrap_scalars",
"wrap_scalars",
Expand Down Expand Up @@ -55,3 +56,29 @@ def generic_query_apply_params(q: TAbstractQuery, params: RawParams) -> TAbstrac
q = q.offset(params.offset)

return q


def get_mongo_pipeline_filter_end(
aggregate_pipeline: list[dict[str, Any]],
) -> int:
"""
Get the index of the stage in the aggregation pipeline where the number or order
of documents in the pipeline no longer changes.
"""

# MongoDB aggregation pipeline stages that do not change the number or order
# of documents in the pipeline output.
transform_stages = [
"$addFields",
"$graphLookup",
"$lookup",
"$project",
"$replaceRoot",
"$replaceWith",
"$set",
"$unset",
]
for i, stage in enumerate(reversed(aggregate_pipeline)):
if any(stage_name not in transform_stages for stage_name in stage):
return len(aggregate_pipeline) - i
return 0
11 changes: 10 additions & 1 deletion tests/ext/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from fastapi_pagination.ext.utils import len_or_none, unwrap_scalars, wrap_scalars
from fastapi_pagination.ext.utils import get_mongo_pipeline_filter_end, len_or_none, unwrap_scalars, wrap_scalars


def test_len_or_none():
Expand All @@ -20,3 +20,12 @@ def test_wrap_scalars():
assert wrap_scalars([[]]) == [[]]
assert wrap_scalars([1, 2]) == [[1], [2]]
assert wrap_scalars([1, [2, 3]]) == [[1], [2, 3]]


def test_get_mongo_pipeline_filter_end():
assert get_mongo_pipeline_filter_end([]) == 0
assert get_mongo_pipeline_filter_end([{"$match": {}}]) == 1
assert get_mongo_pipeline_filter_end([{"$match": {}}, {"$project": {}}]) == 1
assert get_mongo_pipeline_filter_end([{"$match": {}}, {"$sort": {}}, {"$project": {}}]) == 2
assert get_mongo_pipeline_filter_end([{"$match": {}}, {"$project": {}}, {"$sort": {}}, {"$project": {}}]) == 3
assert get_mongo_pipeline_filter_end([{"$match": {}}, {"$project": {}}, {"$lookup": {}}, {"$project": {}}]) == 1