From f3bd2c27bd83d1141cfb92ed5c88d1e2d2a388d1 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU <68415893+jason810496@users.noreply.github.com> Date: Mon, 21 Oct 2024 17:23:13 +0800 Subject: [PATCH] AIP-84 | Public list tags API (#42959) * AIP-84 | Public list tags API * refactor: upd resp schema, use paginated_select * refactor(test): move test to routers/public folder * refactor: remove OrderBy param, use SortParm --- airflow/api_fastapi/common/parameters.py | 13 ++ .../core_api/openapi/v1-generated.yaml | 78 +++++++++++ .../core_api/routes/public/dags.py | 39 +++++- .../api_fastapi/core_api/serializers/dags.py | 7 + airflow/ui/openapi-gen/queries/common.ts | 25 ++++ airflow/ui/openapi-gen/queries/prefetch.ts | 35 +++++ airflow/ui/openapi-gen/queries/queries.ts | 44 +++++++ airflow/ui/openapi-gen/queries/suspense.ts | 44 +++++++ .../ui/openapi-gen/requests/schemas.gen.ts | 20 +++ .../ui/openapi-gen/requests/services.gen.ts | 33 +++++ airflow/ui/openapi-gen/requests/types.gen.ts | 40 ++++++ .../core_api/routes/public/test_dags.py | 122 +++++++++++++++++- 12 files changed, 497 insertions(+), 3 deletions(-) diff --git a/airflow/api_fastapi/common/parameters.py b/airflow/api_fastapi/common/parameters.py index 4aa8335905ca00..9b265c7583a7ec 100644 --- a/airflow/api_fastapi/common/parameters.py +++ b/airflow/api_fastapi/common/parameters.py @@ -265,6 +265,17 @@ def depends(self, last_dag_run_state: DagRunState | None = None) -> _LastDagRunS return self.set_value(last_dag_run_state) +class _DagTagNamePatternSearch(_SearchParam): + """Search on dag_tag.name.""" + + def __init__(self, skip_none: bool = True) -> None: + super().__init__(DagTag.name, skip_none) + + def depends(self, tag_name_pattern: str | None = None) -> _DagTagNamePatternSearch: + tag_name_pattern = super().transform_aliases(tag_name_pattern) + return self.set_value(tag_name_pattern) + + def _safe_parse_datetime(date_to_check: str) -> datetime: """ Parse datetime and raise error for invalid dates. @@ -299,3 +310,5 @@ def _safe_parse_datetime(date_to_check: str) -> datetime: QueryOwnersFilter = Annotated[_OwnersFilter, Depends(_OwnersFilter().depends)] # DagRun QueryLastDagRunStateFilter = Annotated[_LastDagRunStateFilter, Depends(_LastDagRunStateFilter().depends)] +# DAGTags +QueryDagTagPatternSearch = Annotated[_DagTagNamePatternSearch, Depends(_DagTagNamePatternSearch().depends)] diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index 88dc7428bed6b9..325a6354de2b25 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -291,6 +291,68 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' + /public/dags/tags: + get: + tags: + - DAG + summary: Get Dag Tags + description: Get all DAG tags. + operationId: get_dag_tags + parameters: + - name: limit + in: query + required: false + schema: + type: integer + default: 100 + title: Limit + - name: offset + in: query + required: false + schema: + type: integer + default: 0 + title: Offset + - name: order_by + in: query + required: false + schema: + type: string + default: name + title: Order By + - name: tag_name_pattern + in: query + required: false + schema: + anyOf: + - type: string + - type: 'null' + title: Tag Name Pattern + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/DAGTagCollectionResponse' + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' /public/dags/{dag_id}: get: tags: @@ -1713,6 +1775,22 @@ components: - dataset_triggered title: DAGRunTypes description: DAG Run Types for responses. + DAGTagCollectionResponse: + properties: + tags: + items: + type: string + type: array + title: Tags + total_entries: + type: integer + title: Total Entries + type: object + required: + - tags + - total_entries + title: DAGTagCollectionResponse + description: DAG Tags Collection serializer for responses. DagProcessorInfoSchema: properties: status: diff --git a/airflow/api_fastapi/core_api/routes/public/dags.py b/airflow/api_fastapi/core_api/routes/public/dags.py index 81293211fc92e8..c7b753b5cdbd9b 100644 --- a/airflow/api_fastapi/core_api/routes/public/dags.py +++ b/airflow/api_fastapi/core_api/routes/public/dags.py @@ -18,7 +18,7 @@ from __future__ import annotations from fastapi import Depends, HTTPException, Query, Request, Response -from sqlalchemy import update +from sqlalchemy import select, update from sqlalchemy.orm import Session from typing_extensions import Annotated @@ -32,6 +32,7 @@ QueryDagDisplayNamePatternSearch, QueryDagIdPatternSearch, QueryDagIdPatternSearchWithNone, + QueryDagTagPatternSearch, QueryLastDagRunStateFilter, QueryLimit, QueryOffset, @@ -48,9 +49,10 @@ DAGDetailsResponse, DAGPatchBody, DAGResponse, + DAGTagCollectionResponse, ) from airflow.exceptions import AirflowException, DagNotFound -from airflow.models import DAG, DagModel +from airflow.models import DAG, DagModel, DagTag dags_router = AirflowRouter(tags=["DAG"], prefix="/dags") @@ -95,6 +97,39 @@ async def get_dags( ) +@dags_router.get( + "/tags", + responses=create_openapi_http_exception_doc([401, 403]), +) +async def get_dag_tags( + limit: QueryLimit, + offset: QueryOffset, + order_by: Annotated[ + SortParam, + Depends( + SortParam( + ["name"], + DagTag, + ).dynamic_depends() + ), + ], + tag_name_pattern: QueryDagTagPatternSearch, + session: Annotated[Session, Depends(get_session)], +) -> DAGTagCollectionResponse: + """Get all DAG tags.""" + base_select = select(DagTag.name).group_by(DagTag.name) + dag_tags_select, total_entries = paginated_select( + base_select=base_select, + filters=[tag_name_pattern], + order_by=order_by, + offset=offset, + limit=limit, + session=session, + ) + dag_tags = session.execute(dag_tags_select).scalars().all() + return DAGTagCollectionResponse(tags=[dag_tag for dag_tag in dag_tags], total_entries=total_entries) + + @dags_router.get("/{dag_id}", responses=create_openapi_http_exception_doc([400, 401, 403, 404, 422])) async def get_dag( dag_id: str, session: Annotated[Session, Depends(get_session)], request: Request diff --git a/airflow/api_fastapi/core_api/serializers/dags.py b/airflow/api_fastapi/core_api/serializers/dags.py index c9d48aac222eb2..39e85ea8c6f0ea 100644 --- a/airflow/api_fastapi/core_api/serializers/dags.py +++ b/airflow/api_fastapi/core_api/serializers/dags.py @@ -156,3 +156,10 @@ def get_params(cls, params: abc.MutableMapping | None) -> dict | None: def concurrency(self) -> int: """Return max_active_tasks as concurrency.""" return self.max_active_tasks + + +class DAGTagCollectionResponse(BaseModel): + """DAG Tags Collection serializer for responses.""" + + tags: list[str] + total_entries: int diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index 98317d58ee4f3a..5e950de8447e95 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -102,6 +102,31 @@ export const UseDagServiceGetDagsKeyFn = ( }, ]), ]; +export type DagServiceGetDagTagsDefaultResponse = Awaited< + ReturnType +>; +export type DagServiceGetDagTagsQueryResult< + TData = DagServiceGetDagTagsDefaultResponse, + TError = unknown, +> = UseQueryResult; +export const useDagServiceGetDagTagsKey = "DagServiceGetDagTags"; +export const UseDagServiceGetDagTagsKeyFn = ( + { + limit, + offset, + orderBy, + tagNamePattern, + }: { + limit?: number; + offset?: number; + orderBy?: string; + tagNamePattern?: string; + } = {}, + queryKey?: Array, +) => [ + useDagServiceGetDagTagsKey, + ...(queryKey ?? [{ limit, offset, orderBy, tagNamePattern }]), +]; export type DagServiceGetDagDefaultResponse = Awaited< ReturnType >; diff --git a/airflow/ui/openapi-gen/queries/prefetch.ts b/airflow/ui/openapi-gen/queries/prefetch.ts index af05a26fc26183..8807201d429922 100644 --- a/airflow/ui/openapi-gen/queries/prefetch.ts +++ b/airflow/ui/openapi-gen/queries/prefetch.ts @@ -129,6 +129,41 @@ export const prefetchUseDagServiceGetDags = ( tags, }), }); +/** + * Get Dag Tags + * Get all DAG tags. + * @param data The data for the request. + * @param data.limit + * @param data.offset + * @param data.orderBy + * @param data.tagNamePattern + * @returns DAGTagCollectionResponse Successful Response + * @throws ApiError + */ +export const prefetchUseDagServiceGetDagTags = ( + queryClient: QueryClient, + { + limit, + offset, + orderBy, + tagNamePattern, + }: { + limit?: number; + offset?: number; + orderBy?: string; + tagNamePattern?: string; + } = {}, +) => + queryClient.prefetchQuery({ + queryKey: Common.UseDagServiceGetDagTagsKeyFn({ + limit, + offset, + orderBy, + tagNamePattern, + }), + queryFn: () => + DagService.getDagTags({ limit, offset, orderBy, tagNamePattern }), + }); /** * Get Dag * Get basic information about a DAG. diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index bf9a744f6cb53e..ac6939d9f6f7a2 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -157,6 +157,50 @@ export const useDagServiceGetDags = < }) as TData, ...options, }); +/** + * Get Dag Tags + * Get all DAG tags. + * @param data The data for the request. + * @param data.limit + * @param data.offset + * @param data.orderBy + * @param data.tagNamePattern + * @returns DAGTagCollectionResponse Successful Response + * @throws ApiError + */ +export const useDagServiceGetDagTags = < + TData = Common.DagServiceGetDagTagsDefaultResponse, + TError = unknown, + TQueryKey extends Array = unknown[], +>( + { + limit, + offset, + orderBy, + tagNamePattern, + }: { + limit?: number; + offset?: number; + orderBy?: string; + tagNamePattern?: string; + } = {}, + queryKey?: TQueryKey, + options?: Omit, "queryKey" | "queryFn">, +) => + useQuery({ + queryKey: Common.UseDagServiceGetDagTagsKeyFn( + { limit, offset, orderBy, tagNamePattern }, + queryKey, + ), + queryFn: () => + DagService.getDagTags({ + limit, + offset, + orderBy, + tagNamePattern, + }) as TData, + ...options, + }); /** * Get Dag * Get basic information about a DAG. diff --git a/airflow/ui/openapi-gen/queries/suspense.ts b/airflow/ui/openapi-gen/queries/suspense.ts index fad0d5b7a5a8e2..1e497418893e8a 100644 --- a/airflow/ui/openapi-gen/queries/suspense.ts +++ b/airflow/ui/openapi-gen/queries/suspense.ts @@ -151,6 +151,50 @@ export const useDagServiceGetDagsSuspense = < }) as TData, ...options, }); +/** + * Get Dag Tags + * Get all DAG tags. + * @param data The data for the request. + * @param data.limit + * @param data.offset + * @param data.orderBy + * @param data.tagNamePattern + * @returns DAGTagCollectionResponse Successful Response + * @throws ApiError + */ +export const useDagServiceGetDagTagsSuspense = < + TData = Common.DagServiceGetDagTagsDefaultResponse, + TError = unknown, + TQueryKey extends Array = unknown[], +>( + { + limit, + offset, + orderBy, + tagNamePattern, + }: { + limit?: number; + offset?: number; + orderBy?: string; + tagNamePattern?: string; + } = {}, + queryKey?: TQueryKey, + options?: Omit, "queryKey" | "queryFn">, +) => + useSuspenseQuery({ + queryKey: Common.UseDagServiceGetDagTagsKeyFn( + { limit, offset, orderBy, tagNamePattern }, + queryKey, + ), + queryFn: () => + DagService.getDagTags({ + limit, + offset, + orderBy, + tagNamePattern, + }) as TData, + ...options, + }); /** * Get Dag * Get basic information about a DAG. diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts index 23458a325883de..aa6437118eeeaa 100644 --- a/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -1013,6 +1013,26 @@ export const $DAGRunTypes = { description: "DAG Run Types for responses.", } as const; +export const $DAGTagCollectionResponse = { + properties: { + tags: { + items: { + type: "string", + }, + type: "array", + title: "Tags", + }, + total_entries: { + type: "integer", + title: "Total Entries", + }, + }, + type: "object", + required: ["tags", "total_entries"], + title: "DAGTagCollectionResponse", + description: "DAG Tags Collection serializer for responses.", +} as const; + export const $DagProcessorInfoSchema = { properties: { status: { diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 0fa959e9725faa..d6f46b283faa13 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -11,6 +11,8 @@ import type { GetDagsResponse, PatchDagsData, PatchDagsResponse, + GetDagTagsData, + GetDagTagsResponse, GetDagData, GetDagResponse, PatchDagData, @@ -186,6 +188,37 @@ export class DagService { }); } + /** + * Get Dag Tags + * Get all DAG tags. + * @param data The data for the request. + * @param data.limit + * @param data.offset + * @param data.orderBy + * @param data.tagNamePattern + * @returns DAGTagCollectionResponse Successful Response + * @throws ApiError + */ + public static getDagTags( + data: GetDagTagsData = {}, + ): CancelablePromise { + return __request(OpenAPI, { + method: "GET", + url: "/public/dags/tags", + query: { + limit: data.limit, + offset: data.offset, + order_by: data.orderBy, + tag_name_pattern: data.tagNamePattern, + }, + errors: { + 401: "Unauthorized", + 403: "Forbidden", + 422: "Validation Error", + }, + }); + } + /** * Get Dag * Get basic information about a DAG. diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index 8cbc0b7e0dabd5..c795fce3e540d4 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -177,6 +177,14 @@ export type DAGRunTypes = { dataset_triggered: number; }; +/** + * DAG Tags Collection serializer for responses. + */ +export type DAGTagCollectionResponse = { + tags: Array; + total_entries: number; +}; + /** * Schema for DagProcessor info. */ @@ -387,6 +395,15 @@ export type PatchDagsData = { export type PatchDagsResponse = DAGCollectionResponse; +export type GetDagTagsData = { + limit?: number; + offset?: number; + orderBy?: string; + tagNamePattern?: string | null; +}; + +export type GetDagTagsResponse = DAGTagCollectionResponse; + export type GetDagData = { dagId: string; }; @@ -577,6 +594,29 @@ export type $OpenApiTs = { }; }; }; + "/public/dags/tags": { + get: { + req: GetDagTagsData; + res: { + /** + * Successful Response + */ + 200: DAGTagCollectionResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; + }; "/public/dags/{dag_id}": { get: { req: GetDagData; diff --git a/tests/api_fastapi/core_api/routes/public/test_dags.py b/tests/api_fastapi/core_api/routes/public/test_dags.py index edc350c27b84ac..a48040482023ea 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dags.py +++ b/tests/api_fastapi/core_api/routes/public/test_dags.py @@ -21,7 +21,7 @@ import pendulum import pytest -from airflow.models.dag import DagModel +from airflow.models.dag import DagModel, DagTag from airflow.models.dagrun import DagRun from airflow.operators.empty import EmptyOperator from airflow.utils.session import provide_session @@ -88,6 +88,11 @@ def _create_deactivated_paused_dag(self, session=None): session.add(dagrun_failed) session.add(dagrun_success) + def _create_dag_tags(self, session=None): + session.add(DagTag(dag_id=DAG1_ID, name="tag_2")) + session.add(DagTag(dag_id=DAG2_ID, name="tag_1")) + session.add(DagTag(dag_id=DAG3_ID, name="tag_1")) + @pytest.fixture(autouse=True) @provide_session def setup(self, dag_maker, session=None) -> None: @@ -118,6 +123,7 @@ def setup(self, dag_maker, session=None) -> None: EmptyOperator(task_id=TASK_ID) self._create_deactivated_paused_dag(session) + self._create_dag_tags(session) dag_maker.dagbag.sync_to_db() dag_maker.dag_model.has_task_concurrency_limits = True @@ -386,6 +392,120 @@ def test_get_dag(self, test_client, query_params, dag_id, expected_status_code, assert res_json == expected +class TestGetDagTags(TestDagEndpoint): + """Unit tests for Get DAG Tags.""" + + @pytest.mark.parametrize( + "query_params, expected_status_code, expected_dag_tags, expected_total_entries", + [ + # test with offset, limit, and without any tag_name_pattern + ( + {}, + 200, + [ + "example", + "tag_1", + "tag_2", + ], + 3, + ), + ( + {"offset": 1}, + 200, + [ + "tag_1", + "tag_2", + ], + 3, + ), + ( + {"limit": 2}, + 200, + [ + "example", + "tag_1", + ], + 3, + ), + ( + {"offset": 1, "limit": 2}, + 200, + [ + "tag_1", + "tag_2", + ], + 3, + ), + # test with tag_name_pattern + ( + {"tag_name_pattern": "invalid"}, + 200, + [], + 0, + ), + ( + {"tag_name_pattern": "1"}, + 200, + ["tag_1"], + 1, + ), + ( + {"tag_name_pattern": "tag%"}, + 200, + ["tag_1", "tag_2"], + 2, + ), + # test order_by + ( + {"order_by": "-name"}, + 200, + ["tag_2", "tag_1", "example"], + 3, + ), + # test all query params + ( + {"tag_name_pattern": "t%", "order_by": "-name", "offset": 1, "limit": 1}, + 200, + ["tag_1"], + 2, + ), + ( + {"tag_name_pattern": "~", "offset": 1, "limit": 2}, + 200, + ["tag_1", "tag_2"], + 3, + ), + # test invalid query params + ( + {"order_by": "dag_id"}, + 400, + None, + None, + ), + ( + {"order_by": "-dag_id"}, + 400, + None, + None, + ), + ], + ) + def test_get_dag_tags( + self, test_client, query_params, expected_status_code, expected_dag_tags, expected_total_entries + ): + response = test_client.get("/public/dags/tags", params=query_params) + assert response.status_code == expected_status_code + if expected_status_code != 200: + return + + res_json = response.json() + expected = { + "tags": expected_dag_tags, + "total_entries": expected_total_entries, + } + assert res_json == expected + + class TestDeleteDAG(TestDagEndpoint): """Unit tests for Delete DAG."""