Skip to content

Commit

Permalink
refactor: upd resp schema, use paginated_select
Browse files Browse the repository at this point in the history
  • Loading branch information
jason810496 committed Oct 18, 2024
1 parent e67c4f8 commit 0bb6154
Show file tree
Hide file tree
Showing 12 changed files with 275 additions and 163 deletions.
36 changes: 35 additions & 1 deletion airflow/api_fastapi/common/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from abc import ABC, abstractmethod
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Generic, List, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Generic, List, Literal, TypeVar

from fastapi import Depends, HTTPException, Query
from pendulum.parsing.exceptions import ParserError
Expand Down Expand Up @@ -128,6 +128,26 @@ def transform_aliases(self, value: str | None) -> str | None:
return value


class _OrderByParam(BaseParam[str]):
"""Order result by specified attribute ascending or descending."""

def __init__(self, attribute: ColumnElement, skip_none: bool = True) -> None:
super().__init__(skip_none)
self.attribute: ColumnElement = attribute
self.value: Literal["asc", "desc"] | None = None

def to_orm(self, select: Select) -> Select:
if self.value is None and self.skip_none:
return select
asc_stmt = select.order_by(self.attribute.asc())
if self.value is None:
return asc_stmt
return asc_stmt if self.value == "asc" else select.order_by(self.attribute.desc())

def depends(self, order_by: str = "asc") -> _OrderByParam:
return self.set_value(order_by)


class _DagIdPatternSearch(_SearchParam):
"""Search on dag_id."""

Expand Down Expand Up @@ -265,6 +285,17 @@ def depends(self, last_dag_run_state: DagRunState | None = None) -> _LastDagRunS
return self.set_value(last_dag_run_state)


class _DagTagNamePatternSearch(_SearchParam):
"""Search on dag_tag.name."""

def __init__(self, skip_none: bool = True) -> None:
super().__init__(DagTag.name, skip_none)

def depends(self, tag_name_pattern: str | None = None) -> _DagTagNamePatternSearch:
tag_name_pattern = super().transform_aliases(tag_name_pattern)
return self.set_value(tag_name_pattern)


def _safe_parse_datetime(date_to_check: str) -> datetime:
"""
Parse datetime and raise error for invalid dates.
Expand Down Expand Up @@ -299,3 +330,6 @@ def _safe_parse_datetime(date_to_check: str) -> datetime:
QueryOwnersFilter = Annotated[_OwnersFilter, Depends(_OwnersFilter().depends)]
# DagRun
QueryLastDagRunStateFilter = Annotated[_LastDagRunStateFilter, Depends(_LastDagRunStateFilter().depends)]
# DAGTags
QueryDagTagOrderBy = Annotated[_OrderByParam, Depends(_OrderByParam(DagTag.name).depends)]
QueryDagTagPatternSearch = Annotated[_DagTagNamePatternSearch, Depends(_DagTagNamePatternSearch().depends)]
66 changes: 40 additions & 26 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -299,30 +299,42 @@ paths:
description: Get all DAG tags.
operationId: get_dag_tags
parameters:
- name: tags
- name: limit
in: query
required: false
schema:
type: array
items:
type: string
title: Tags
type: integer
default: 100
title: Limit
- name: offset
in: query
required: false
schema:
type: integer
default: 0
title: Offset
- name: order_by
in: query
required: false
schema:
type: string
default: asc
title: Order By
- name: tag_name_pattern
in: query
required: false
schema:
anyOf:
- type: string
- type: 'null'
title: Tag Name Pattern
responses:
'200':
description: Successful Response
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/DAGTagResponse'
title: Response Get Dag Tags
'400':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Bad Request
$ref: '#/components/schemas/DAGTagCollectionResponse'
'401':
content:
application/json:
Expand Down Expand Up @@ -1699,20 +1711,22 @@ components:
- latest_dag_processor_heartbeat
title: DagProcessorInfoSchema
description: Schema for DagProcessor info.
DAGTagResponse:
DAGTagCollectionResponse:
properties:
name:
type: string
title: Name
selected:
type: boolean
title: Selected
tags:
items:
type: string
type: array
title: Tags
total_entries:
type: integer
title: Total Entries
type: object
required:
- name
- selected
title: DAGTagResponse
description: DAG Tags serializer for responses.
- tags
- total_entries
title: DAGTagCollectionResponse
description: DAG Tags Collection serializer for responses.
DagRunState:
type: string
enum:
Expand Down
37 changes: 21 additions & 16 deletions airflow/api_fastapi/core_api/routes/public/dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

from fastapi import Depends, HTTPException, Query, Request, Response
from sqlalchemy import distinct, update
from sqlalchemy import distinct, select, update
from sqlalchemy.orm import Session
from typing_extensions import Annotated

Expand All @@ -32,6 +32,8 @@
QueryDagDisplayNamePatternSearch,
QueryDagIdPatternSearch,
QueryDagIdPatternSearchWithNone,
QueryDagTagOrderBy,
QueryDagTagPatternSearch,
QueryLastDagRunStateFilter,
QueryLimit,
QueryOffset,
Expand All @@ -48,7 +50,7 @@
DAGDetailsResponse,
DAGPatchBody,
DAGResponse,
DAGTagResponse,
DAGTagCollectionResponse,
)
from airflow.exceptions import AirflowException, DagNotFound
from airflow.models import DAG, DagModel, DagTag
Expand Down Expand Up @@ -98,24 +100,27 @@ async def get_dags(

@dags_router.get(
"/tags",
response_model=list[DAGTagResponse],
responses=create_openapi_http_exception_doc([400, 401, 403]),
responses=create_openapi_http_exception_doc([401, 403]),
)
async def get_dag_tags(
tags: QueryTagsFilter,
limit: QueryLimit,
offset: QueryOffset,
order_by: QueryDagTagOrderBy,
tag_name_pattern: QueryDagTagPatternSearch,
session: Annotated[Session, Depends(get_session)],
) -> list[DAGTagResponse]:
) -> DAGTagCollectionResponse:
"""Get all DAG tags."""
dag_tag_names = session.query(distinct(DagTag.name)).order_by(DagTag.name).all()
if not dag_tag_names:
return []
selected_dag_tags = {}
if tags.value:
selected_dag_tags = {tag: True for tag in tags.value}
return [
DAGTagResponse(name=tag_name_row[0], selected=selected_dag_tags.get(tag_name_row[0], False))
for tag_name_row in dag_tag_names
]
base_select = select(distinct(DagTag.name))
dag_tags_select, total_entries = paginated_select(
base_select=base_select,
filters=[tag_name_pattern],
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)
dag_tags = session.execute(dag_tags_select).scalars().all()
return DAGTagCollectionResponse(tags=[dag_tag for dag_tag in dag_tags], total_entries=total_entries)


@dags_router.get("/{dag_id}", responses=create_openapi_http_exception_doc([400, 401, 403, 404, 422]))
Expand Down
8 changes: 4 additions & 4 deletions airflow/api_fastapi/core_api/serializers/dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ def concurrency(self) -> int:
return self.max_active_tasks


class DAGTagResponse(BaseModel):
"""DAG Tags serializer for responses."""
class DAGTagCollectionResponse(BaseModel):
"""DAG Tags Collection serializer for responses."""

name: str
selected: bool
tags: list[str]
total_entries: int
15 changes: 12 additions & 3 deletions airflow/ui/openapi-gen/queries/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,21 @@ export type DagServiceGetDagTagsQueryResult<
export const useDagServiceGetDagTagsKey = "DagServiceGetDagTags";
export const UseDagServiceGetDagTagsKeyFn = (
{
tags,
limit,
offset,
orderBy,
tagNamePattern,
}: {
tags?: string[];
limit?: number;
offset?: number;
orderBy?: string;
tagNamePattern?: string;
} = {},
queryKey?: Array<unknown>,
) => [useDagServiceGetDagTagsKey, ...(queryKey ?? [{ tags }])];
) => [
useDagServiceGetDagTagsKey,
...(queryKey ?? [{ limit, offset, orderBy, tagNamePattern }]),
];
export type DagServiceGetDagDefaultResponse = Awaited<
ReturnType<typeof DagService.getDag>
>;
Expand Down
27 changes: 21 additions & 6 deletions airflow/ui/openapi-gen/queries/prefetch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,21 +132,36 @@ export const prefetchUseDagServiceGetDags = (
* Get Dag Tags
* Get all DAG tags.
* @param data The data for the request.
* @param data.tags
* @returns DAGTagResponse Successful Response
* @param data.limit
* @param data.offset
* @param data.orderBy
* @param data.tagNamePattern
* @returns DAGTagCollectionResponse Successful Response
* @throws ApiError
*/
export const prefetchUseDagServiceGetDagTags = (
queryClient: QueryClient,
{
tags,
limit,
offset,
orderBy,
tagNamePattern,
}: {
tags?: string[];
limit?: number;
offset?: number;
orderBy?: string;
tagNamePattern?: string;
} = {},
) =>
queryClient.prefetchQuery({
queryKey: Common.UseDagServiceGetDagTagsKeyFn({ tags }),
queryFn: () => DagService.getDagTags({ tags }),
queryKey: Common.UseDagServiceGetDagTagsKeyFn({
limit,
offset,
orderBy,
tagNamePattern,
}),
queryFn: () =>
DagService.getDagTags({ limit, offset, orderBy, tagNamePattern }),
});
/**
* Get Dag
Expand Down
30 changes: 24 additions & 6 deletions airflow/ui/openapi-gen/queries/queries.ts
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,11 @@ export const useDagServiceGetDags = <
* Get Dag Tags
* Get all DAG tags.
* @param data The data for the request.
* @param data.tags
* @returns DAGTagResponse Successful Response
* @param data.limit
* @param data.offset
* @param data.orderBy
* @param data.tagNamePattern
* @returns DAGTagCollectionResponse Successful Response
* @throws ApiError
*/
export const useDagServiceGetDagTags = <
Expand All @@ -169,16 +172,31 @@ export const useDagServiceGetDagTags = <
TQueryKey extends Array<unknown> = unknown[],
>(
{
tags,
limit,
offset,
orderBy,
tagNamePattern,
}: {
tags?: string[];
limit?: number;
offset?: number;
orderBy?: string;
tagNamePattern?: string;
} = {},
queryKey?: TQueryKey,
options?: Omit<UseQueryOptions<TData, TError>, "queryKey" | "queryFn">,
) =>
useQuery<TData, TError>({
queryKey: Common.UseDagServiceGetDagTagsKeyFn({ tags }, queryKey),
queryFn: () => DagService.getDagTags({ tags }) as TData,
queryKey: Common.UseDagServiceGetDagTagsKeyFn(
{ limit, offset, orderBy, tagNamePattern },
queryKey,
),
queryFn: () =>
DagService.getDagTags({
limit,
offset,
orderBy,
tagNamePattern,
}) as TData,
...options,
});
/**
Expand Down
Loading

0 comments on commit 0bb6154

Please sign in to comment.