Skip to content

Commit

Permalink
refactor: remove OrderBy param, use SortParm
Browse files Browse the repository at this point in the history
  • Loading branch information
jason810496 committed Oct 18, 2024
1 parent 16f0a6d commit 74977b9
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 29 deletions.
23 changes: 1 addition & 22 deletions airflow/api_fastapi/common/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from abc import ABC, abstractmethod
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Generic, List, Literal, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Generic, List, TypeVar

from fastapi import Depends, HTTPException, Query
from pendulum.parsing.exceptions import ParserError
Expand Down Expand Up @@ -128,26 +128,6 @@ def transform_aliases(self, value: str | None) -> str | None:
return value


class _OrderByParam(BaseParam[str]):
"""Order result by specified attribute ascending or descending."""

def __init__(self, attribute: ColumnElement, skip_none: bool = True) -> None:
super().__init__(skip_none)
self.attribute: ColumnElement = attribute
self.value: Literal["asc", "desc"] | None = None

def to_orm(self, select: Select) -> Select:
if self.value is None and self.skip_none:
return select
asc_stmt = select.order_by(self.attribute.asc())
if self.value is None:
return asc_stmt
return asc_stmt if self.value == "asc" else select.order_by(self.attribute.desc())

def depends(self, order_by: str = "asc") -> _OrderByParam:
return self.set_value(order_by)


class _DagIdPatternSearch(_SearchParam):
"""Search on dag_id."""

Expand Down Expand Up @@ -331,5 +311,4 @@ def _safe_parse_datetime(date_to_check: str) -> datetime:
# DagRun
QueryLastDagRunStateFilter = Annotated[_LastDagRunStateFilter, Depends(_LastDagRunStateFilter().depends)]
# DAGTags
QueryDagTagOrderBy = Annotated[_OrderByParam, Depends(_OrderByParam(DagTag.name, skip_none=False).depends)]
QueryDagTagPatternSearch = Annotated[_DagTagNamePatternSearch, Depends(_DagTagNamePatternSearch().depends)]
2 changes: 1 addition & 1 deletion airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ paths:
required: false
schema:
type: string
default: asc
default: name
title: Order By
- name: tag_name_pattern
in: query
Expand Down
15 changes: 11 additions & 4 deletions airflow/api_fastapi/core_api/routes/public/dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

from fastapi import Depends, HTTPException, Query, Request, Response
from sqlalchemy import distinct, select, update
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from typing_extensions import Annotated

Expand All @@ -32,7 +32,6 @@
QueryDagDisplayNamePatternSearch,
QueryDagIdPatternSearch,
QueryDagIdPatternSearchWithNone,
QueryDagTagOrderBy,
QueryDagTagPatternSearch,
QueryLastDagRunStateFilter,
QueryLimit,
Expand Down Expand Up @@ -105,12 +104,20 @@ async def get_dags(
async def get_dag_tags(
limit: QueryLimit,
offset: QueryOffset,
order_by: QueryDagTagOrderBy,
order_by: Annotated[
SortParam,
Depends(
SortParam(
["name"],
DagTag,
).dynamic_depends()
),
],
tag_name_pattern: QueryDagTagPatternSearch,
session: Annotated[Session, Depends(get_session)],
) -> DAGTagCollectionResponse:
"""Get all DAG tags."""
base_select = select(distinct(DagTag.name))
base_select = select(DagTag.name).group_by(DagTag.name)
dag_tags_select, total_entries = paginated_select(
base_select=base_select,
filters=[tag_name_pattern],
Expand Down
17 changes: 15 additions & 2 deletions tests/api_fastapi/core_api/routes/public/test_dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,14 +457,14 @@ class TestGetDagTags(TestDagEndpoint):
),
# test order_by
(
{"order_by": "desc"},
{"order_by": "-name"},
200,
["tag_2", "tag_1", "example"],
3,
),
# test all query params
(
{"tag_name_pattern": "t%", "order_by": "desc", "offset": 1, "limit": 1},
{"tag_name_pattern": "t%", "order_by": "-name", "offset": 1, "limit": 1},
200,
["tag_1"],
2,
Expand All @@ -475,6 +475,19 @@ class TestGetDagTags(TestDagEndpoint):
["tag_1", "tag_2"],
3,
),
# test invalid query params
(
{"order_by": "dag_id"},
400,
None,
None,
),
(
{"order_by": "-dag_id"},
400,
None,
None,
),
],
)
def test_get_dag_tags(
Expand Down

0 comments on commit 74977b9

Please sign in to comment.