Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-84 Migrate get connections to FastAPI API #42571 #42782

Merged
1 change: 1 addition & 0 deletions airflow/api_connexion/endpoints/connection_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def get_connection(*, connection_id: str, session: Session = NEW_SESSION) -> API
@security.requires_access_connection("GET")
@format_parameters({"limit": check_limit})
@provide_session
@mark_fastapi_migration_done
def get_connections(
*,
limit: int,
Expand Down
82 changes: 79 additions & 3 deletions airflow/api_fastapi/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,66 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/public/connections/:
get:
tags:
- Connection
summary: Get Connections
description: Get all connection entries.
operationId: get_connections
parameters:
- name: limit
in: query
required: false
schema:
type: integer
default: 100
title: Limit
- name: offset
in: query
required: false
schema:
type: integer
default: 0
title: Offset
- name: order_by
in: query
required: false
schema:
type: string
default: id
title: Order By
responses:
'200':
description: Successful Response
content:
application/json:
schema:
$ref: '#/components/schemas/ConnectionCollectionResponse'
'401':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Unauthorized
'403':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Forbidden
'404':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Not Found
'422':
description: Validation Error
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/public/variables/{variable_key}:
delete:
tags:
Expand Down Expand Up @@ -886,11 +946,27 @@ paths:
$ref: '#/components/schemas/HTTPValidationError'
components:
schemas:
ConnectionCollectionResponse:
properties:
connections:
items:
$ref: '#/components/schemas/ConnectionResponse'
type: array
title: Connections
total_entries:
type: integer
title: Total Entries
type: object
required:
- connections
- total_entries
title: ConnectionCollectionResponse
description: DAG Collection serializer for responses.
ConnectionResponse:
properties:
conn_id:
connection_id:
type: string
title: Conn Id
title: Connection Id
conn_type:
type: string
title: Conn Type
Expand Down Expand Up @@ -926,7 +1002,7 @@ components:
title: Extra
type: object
required:
- conn_id
- connection_id
- conn_type
- description
- host
Expand Down
48 changes: 40 additions & 8 deletions airflow/api_fastapi/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,19 @@

from __future__ import annotations

import importlib
from abc import ABC, abstractmethod
from datetime import datetime
from typing import TYPE_CHECKING, Any, Generic, List, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Generic, List, TypeVar

from fastapi import Depends, HTTPException, Query
from pendulum.parsing.exceptions import ParserError
from pydantic import AfterValidator
from sqlalchemy import case, or_
from sqlalchemy import Column, case, or_
from sqlalchemy.inspection import inspect
from typing_extensions import Annotated, Self

from airflow.models import Base, Connection
from airflow.models.dag import DagModel, DagTag
from airflow.models.dagrun import DagRun
from airflow.utils import timezone
Expand Down Expand Up @@ -154,11 +157,17 @@ class SortParam(BaseParam[str]):
attr_mapping = {
"last_run_state": DagRun.state,
"last_run_start_date": DagRun.start_date,
"connection_id": Connection.conn_id,
}

def __init__(self, allowed_attrs: list[str]) -> None:
def __init__(
self,
allowed_attrs: list[str],
model: Base,
) -> None:
super().__init__()
self.allowed_attrs = allowed_attrs
self.model = model
pierrejeambrun marked this conversation as resolved.
Show resolved Hide resolved

def to_orm(self, select: Select) -> Select:
if self.skip_none is False:
Expand All @@ -175,7 +184,9 @@ def to_orm(self, select: Select) -> Select:
f"the attribute does not exist on the model",
)

column = self.attr_mapping.get(lstriped_orderby, None) or getattr(DagModel, lstriped_orderby)
column: Column = self.attr_mapping.get(lstriped_orderby, None) or getattr(
self.model, lstriped_orderby
)

# MySQL does not support `nullslast`, and True/False ordering depends on the
# database implementation.
Expand All @@ -185,12 +196,33 @@ def to_orm(self, select: Select) -> Select:
select = select.order_by(None)

if self.value[0] == "-":
return select.order_by(nullscheck, column.desc(), DagModel.dag_id.desc())
return select.order_by(nullscheck, column.desc(), column.desc())
else:
return select.order_by(nullscheck, column.asc(), DagModel.dag_id.asc())
return select.order_by(nullscheck, column.asc(), column.asc())

def get_primary_key(self) -> str:
"""Get the primary key of the model of SortParam object."""
return inspect(self.model).primary_key[0].name

@staticmethod
def get_primary_key_of_given_model_string(model_string: str) -> str:
"""
Get the primary key of given 'airflow.models' class as a string. The class should have driven be from 'airflow.models.base'.

:param model_string: The string representation of the model class.
:return: The primary key of the model class.
"""
dynamic_return_model = getattr(importlib.import_module("airflow.models"), model_string)
return inspect(dynamic_return_model).primary_key[0].name

def depends(self, *args: Any, **kwargs: Any) -> Self:
raise NotImplementedError("Use dynamic_depends, depends not implemented.")

def dynamic_depends(self) -> Callable:
def inner(order_by: str = self.get_primary_key()) -> SortParam:
return self.set_value(self.get_primary_key() if order_by == "" else order_by)

def depends(self, order_by: str = "dag_id") -> SortParam:
return self.set_value(order_by)
return inner


class _TagsFilter(BaseParam[List[str]]):
Expand Down
9 changes: 8 additions & 1 deletion airflow/api_fastapi/serializers/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
class ConnectionResponse(BaseModel):
"""Connection serializer for responses."""

connection_id: str = Field(alias="conn_id")
connection_id: str = Field(serialization_alias="connection_id", validation_alias="conn_id")
pierrejeambrun marked this conversation as resolved.
Show resolved Hide resolved
conn_type: str
description: str | None
host: str | None
Expand All @@ -48,3 +48,10 @@ def redact_extra(cls, v: str | None) -> str | None:
except json.JSONDecodeError:
# we can't redact fields in an unstructured `extra`
return v


class ConnectionCollectionResponse(BaseModel):
"""DAG Collection serializer for responses."""

connections: list[ConnectionResponse]
total_entries: int
42 changes: 40 additions & 2 deletions airflow/api_fastapi/views/public/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
from sqlalchemy.orm import Session
from typing_extensions import Annotated

from airflow.api_fastapi.db.common import get_session
from airflow.api_fastapi.db.common import get_session, paginated_select
from airflow.api_fastapi.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.serializers.connections import ConnectionResponse
from airflow.api_fastapi.parameters import QueryLimit, QueryOffset, SortParam
from airflow.api_fastapi.serializers.connections import ConnectionCollectionResponse, ConnectionResponse
from airflow.api_fastapi.views.router import AirflowRouter
from airflow.models import Connection

Expand Down Expand Up @@ -63,3 +64,40 @@ async def get_connection(
raise HTTPException(404, f"The Connection with connection_id: `{connection_id}` was not found")

return ConnectionResponse.model_validate(connection, from_attributes=True)


@connections_router.get(
"/",
responses=create_openapi_http_exception_doc([401, 403, 404]),
)
async def get_connections(
limit: QueryLimit,
offset: QueryOffset,
order_by: Annotated[
SortParam,
Depends(
SortParam(
["connection_id", "conn_type", "description", "host", "port", "id"], Connection
).dynamic_depends()
),
],
session: Annotated[Session, Depends(get_session)],
) -> ConnectionCollectionResponse:
"""Get all connection entries."""
connection_select, total_entries = paginated_select(
select(Connection),
[],
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)

connections = session.scalars(connection_select).all()

return ConnectionCollectionResponse(
connections=[
ConnectionResponse.model_validate(connection, from_attributes=True) for connection in connections
],
total_entries=total_entries,
)
5 changes: 3 additions & 2 deletions airflow/api_fastapi/views/public/dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ async def get_dags(
SortParam,
Depends(
SortParam(
["dag_id", "dag_display_name", "next_dagrun", "last_run_state", "last_run_start_date"]
).depends
["dag_id", "dag_display_name", "next_dagrun", "last_run_state", "last_run_start_date"],
DagModel,
).dynamic_depends()
),
],
session: Annotated[Session, Depends(get_session)],
Expand Down
24 changes: 24 additions & 0 deletions airflow/ui/openapi-gen/queries/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,30 @@ export const UseConnectionServiceGetConnectionKeyFn = (
useConnectionServiceGetConnectionKey,
...(queryKey ?? [{ connectionId }]),
];
export type ConnectionServiceGetConnectionsDefaultResponse = Awaited<
ReturnType<typeof ConnectionService.getConnections>
>;
export type ConnectionServiceGetConnectionsQueryResult<
TData = ConnectionServiceGetConnectionsDefaultResponse,
TError = unknown,
> = UseQueryResult<TData, TError>;
export const useConnectionServiceGetConnectionsKey =
"ConnectionServiceGetConnections";
export const UseConnectionServiceGetConnectionsKeyFn = (
{
limit,
offset,
orderBy,
}: {
limit?: number;
offset?: number;
orderBy?: string;
} = {},
queryKey?: Array<unknown>,
) => [
useConnectionServiceGetConnectionsKey,
...(queryKey ?? [{ limit, offset, orderBy }]),
];
export type VariableServiceGetVariableDefaultResponse = Awaited<
ReturnType<typeof VariableService.getVariable>
>;
Expand Down
30 changes: 30 additions & 0 deletions airflow/ui/openapi-gen/queries/prefetch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,36 @@ export const prefetchUseConnectionServiceGetConnection = (
queryKey: Common.UseConnectionServiceGetConnectionKeyFn({ connectionId }),
queryFn: () => ConnectionService.getConnection({ connectionId }),
});
/**
* Get Connections
* Get all connection entries.
* @param data The data for the request.
* @param data.limit
* @param data.offset
* @param data.orderBy
* @returns ConnectionCollectionResponse Successful Response
* @throws ApiError
*/
export const prefetchUseConnectionServiceGetConnections = (
queryClient: QueryClient,
{
limit,
offset,
orderBy,
}: {
limit?: number;
offset?: number;
orderBy?: string;
} = {},
) =>
queryClient.prefetchQuery({
queryKey: Common.UseConnectionServiceGetConnectionsKeyFn({
limit,
offset,
orderBy,
}),
queryFn: () => ConnectionService.getConnections({ limit, offset, orderBy }),
});
/**
* Get Variable
* Get a variable entry.
Expand Down
Loading