Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.security import requires_access_connection
from airflow.api_fastapi.core_api.services.public.connections import BulkConnectionService
from airflow.api_fastapi.core_api.services.public.connections import (
BulkConnectionService,
update_orm_from_pydantic,
)
from airflow.api_fastapi.logging.decorators import action_logging
from airflow.configuration import conf
from airflow.models import Connection
Expand Down Expand Up @@ -187,29 +190,19 @@ def patch_connection(
"The connection_id in the request body does not match the URL parameter",
)

non_update_fields = {"connection_id", "conn_id"}
connection = session.scalar(select(Connection).filter_by(conn_id=connection_id).limit(1))
connection: Connection = session.scalar(select(Connection).filter_by(conn_id=connection_id).limit(1))

if connection is None:
raise HTTPException(
status.HTTP_404_NOT_FOUND, f"The Connection with connection_id: `{connection_id}` was not found"
)

fields_to_update = patch_body.model_fields_set

if update_mask:
fields_to_update = fields_to_update.intersection(update_mask)
else:
try:
ConnectionBody(**patch_body.model_dump())
except ValidationError as e:
raise RequestValidationError(errors=e.errors())

data = patch_body.model_dump(include=fields_to_update - non_update_fields, by_alias=True)

for key, val in data.items():
setattr(connection, key, val)
try:
ConnectionBody(**patch_body.model_dump())
except ValidationError as e:
raise RequestValidationError(errors=e.errors())

update_orm_from_pydantic(connection, patch_body, update_mask)
return connection


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,36 @@
from airflow.models.connection import Connection


def update_orm_from_pydantic(
orm_conn: Connection, pydantic_conn: ConnectionBody, update_mask: list[str] | None = None
) -> None:
"""Update ORM object from Pydantic object."""
# Not all fields match and some need setters, therefore copy partly manually via setters
non_update_fields = {"connection_id", "conn_id"}
setter_fields = {"password", "extra"}
fields_set = pydantic_conn.model_fields_set
if "schema_" in fields_set: # Alias is not resolved correctly, need to patch
fields_set.remove("schema_")
fields_set.add("schema")
fields_to_update = fields_set - non_update_fields - setter_fields
if update_mask:
fields_to_update = fields_to_update.intersection(update_mask)
print(fields_to_update)
conn_data = pydantic_conn.model_dump(by_alias=True)
for key, val in conn_data.items():
if key in fields_to_update:
setattr(orm_conn, key, val)

if (not update_mask and "password" in pydantic_conn.model_fields_set) or (
update_mask and "password" in update_mask
):
orm_conn.set_password(pydantic_conn.password)
if (not update_mask and "extra" in pydantic_conn.model_fields_set) or (
update_mask and "extra" in update_mask
):
orm_conn.set_extra(pydantic_conn.extra)


class BulkConnectionService(BulkService[ConnectionBody]):
"""Service for handling bulk operations on connections."""

Expand Down Expand Up @@ -108,12 +138,16 @@ def handle_bulk_update(

for connection in action.entities:
if connection.connection_id in update_connection_ids:
old_connection = self.session.scalar(
old_connection: Connection = self.session.scalar(
select(Connection).filter(Connection.conn_id == connection.connection_id).limit(1)
)
if old_connection is None:
raise ValidationError(
f"The Connection with connection_id: `{connection.connection_id}` was not found"
)
ConnectionBody(**connection.model_dump())
for key, val in connection.model_dump(by_alias=True).items():
setattr(old_connection, key, val)

update_orm_from_pydantic(old_connection, connection)
results.success.append(connection.connection_id)

except HTTPException as e:
Expand Down
2 changes: 1 addition & 1 deletion airflow-core/src/airflow/models/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def get_extra(self) -> str:
self._validate_extra(extra_val, self.conn_id)
return extra_val

def set_extra(self, value: str):
def set_extra(self, value: str | None):
"""Encrypt extra-data and save in object attribute to object."""
if value:
self._validate_extra(value, self.conn_id)
Expand Down
Loading