@@ -50,6 +50,7 @@
)
from airflow.api_fastapi.logging.decorators import action_logging
from airflow.configuration import conf
from airflow.exceptions import AirflowNotFoundException
from airflow.models import Connection
from airflow.secrets.environment_variables import CONN_ENV_PREFIX
from airflow.utils.db import create_default_connections as db_create_default_connections
@@ -207,9 +208,7 @@ def patch_connection(


@connections_router.post("/test", dependencies=[Depends(requires_access_connection(method="POST"))])
-def test_connection(
-    test_body: ConnectionBody,
-) -> ConnectionTestResponse:
+def test_connection(test_body: ConnectionBody) -> ConnectionTestResponse:
"""
Test an API connection.

@@ -227,9 +226,17 @@ def test_connection(
transient_conn_id = get_random_string()
conn_env_var = f"{CONN_ENV_PREFIX}{transient_conn_id.upper()}"
try:
-        data = test_body.model_dump(by_alias=True)
-        data["conn_id"] = transient_conn_id
-        conn = Connection(**data)
+        # Try to get existing connection and merge with provided values
+        try:
+            existing_conn = Connection.get_connection_from_secrets(test_body.connection_id)
+            existing_conn.conn_id = transient_conn_id
+            update_orm_from_pydantic(existing_conn, test_body)
+            conn = existing_conn
+        except AirflowNotFoundException:
+            data = test_body.model_dump(by_alias=True)
+            data["conn_id"] = transient_conn_id
+            conn = Connection(**data)
+
os.environ[conn_env_var] = conn.get_uri()
test_status, test_message = conn.test_connection()
return ConnectionTestResponse.model_validate({"status": test_status, "message": test_message})
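The handler above delegates the field-by-field merge to update_orm_from_pydantic, which is not shown in this hunk. As a rough illustration only, a helper with the behaviour the new tests exercise could look something like the sketch below; the "***" placeholder handling and the exclude_unset policy are assumptions, not necessarily the PR's actual implementation.

# Hypothetical sketch of the merge helper used above. The real
# update_orm_from_pydantic is defined elsewhere in this change and may differ.
SENSITIVE_PLACEHOLDER = "***"  # assumed mask that clients send back for redacted secrets

def update_orm_from_pydantic(existing_conn, test_body):
    # Copy only the fields the client explicitly set in the request body.
    for field, value in test_body.model_dump(exclude_unset=True, by_alias=True).items():
        if field == "conn_id":
            continue  # the transient conn_id is assigned by the caller
        if field == "password" and value == SENSITIVE_PLACEHOLDER:
            continue  # keep the stored secret rather than the redacted mask
        setattr(existing_conn, field, value)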
@@ -16,11 +16,13 @@
# under the License.
from __future__ import annotations

import json
import os
from importlib.metadata import PackageNotFoundError, metadata
from unittest import mock

import pytest
from sqlalchemy import func, select

from airflow.models import Connection
from airflow.secrets.environment_variables import CONN_ENV_PREFIX
@@ -920,6 +922,140 @@ def test_should_respond_403_by_default(self, test_client, body):
"Contact your deployment admin to enable it."
}

@mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
def test_should_merge_password_with_existing_connection(self, test_client, session):
connection = Connection(
conn_id=TEST_CONN_ID,
conn_type="sqlite",
password="existing_password",
)
session.add(connection)
session.commit()
initial_count = session.scalar(select(func.count()).select_from(Connection))

captured_value = {}

def mock_test_connection(self):
captured_value["password"] = self.password
captured_value["conn_type"] = self.conn_type
return True, "mocked"

body = {
"connection_id": TEST_CONN_ID,
"conn_type": "new_sqlite",
"password": "***",
}

with mock.patch.object(Connection, "test_connection", mock_test_connection):
response = test_client.post("/connections/test", json=body)

assert response.status_code == 200
assert response.json()["status"] is True
# Verify that the existing password was used, not "***"
assert captured_value["password"] == "existing_password"
# Verify that payload values were used for the other fields
assert captured_value["conn_type"] == "new_sqlite"

# Verify DB was not mutated
session.expire_all()
db_conn = session.scalar(select(Connection).filter_by(conn_id=TEST_CONN_ID))
assert db_conn.password == "existing_password"
assert session.scalar(select(func.count()).select_from(Connection)) == initial_count

@mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
def test_should_merge_extra_with_existing_connection(self, test_client, session):
connection = Connection(
conn_id=TEST_CONN_ID,
conn_type="fs",
extra='{"path": "/", "existing_key": "existing_value"}',
)
session.add(connection)
session.commit()
initial_count = session.scalar(select(func.count()).select_from(Connection))

captured_extra = {}

def mock_test_connection(self):
captured_extra["value"] = self.extra
return True, "mocked"

body = {
"connection_id": TEST_CONN_ID,
"conn_type": "fs",
"extra": '{"path": "/", "new_key": "new_value"}',
}

with mock.patch.object(Connection, "test_connection", mock_test_connection):
response = test_client.post("/connections/test", json=body)

assert response.status_code == 200
assert response.json()["status"] is True
# Verify that new_key is reflected in the merged extra
merged_extra = json.loads(captured_extra["value"])
assert merged_extra["new_key"] == "new_value"
assert merged_extra["path"] == "/"

# Verify DB was not mutated
session.expire_all()
db_conn = session.scalar(select(Connection).filter_by(conn_id=TEST_CONN_ID))
assert json.loads(db_conn.extra) == {"path": "/", "existing_key": "existing_value"}
assert session.scalar(select(func.count()).select_from(Connection)) == initial_count

@mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
def test_should_merge_both_password_and_extra(self, test_client, session):
connection = Connection(
conn_id=TEST_CONN_ID,
conn_type="fs",
password="existing_password",
extra='{"path": "/", "existing_key": "existing_value"}',
)
session.add(connection)
session.commit()
initial_count = session.scalar(select(func.count()).select_from(Connection))

captured_values = {}

def mock_test_connection(self):
captured_values["password"] = self.password
captured_values["extra"] = self.extra
return True, "mocked"

body = {
"connection_id": TEST_CONN_ID,
"conn_type": "fs",
"password": "***",
"extra": '{"path": "/", "new_key": "new_value"}',
}

with mock.patch.object(Connection, "test_connection", mock_test_connection):
response = test_client.post("/connections/test", json=body)

assert response.status_code == 200
assert response.json()["status"] is True
# Verify that the existing password was used, not "***"
assert captured_values["password"] == "existing_password"
# Verify that new_key is reflected in the merged extra
merged_extra = json.loads(captured_values["extra"])
assert merged_extra["new_key"] == "new_value"
assert merged_extra["path"] == "/"

# Verify DB was not mutated
session.expire_all()
db_conn = session.scalar(select(Connection).filter_by(conn_id=TEST_CONN_ID))
assert db_conn.password == "existing_password"
assert json.loads(db_conn.extra) == {"path": "/", "existing_key": "existing_value"}
assert session.scalar(select(func.count()).select_from(Connection)) == initial_count

@mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
def test_should_test_new_connection_without_existing(self, test_client):
body = {
"connection_id": "non_existent_conn",
"conn_type": "sqlite",
}
response = test_client.post("/connections/test", json=body)
assert response.status_code == 200
assert response.json()["status"] is True
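As a usage note, a client exercising this merge path end to end might look like the sketch below; the base URL, auth header, and connection values are placeholders rather than anything taken from this PR.

# Hypothetical client-side call: test a stored connection while sending the
# redacted password mask, so the server substitutes the stored secret.
import requests

BASE_URL = "http://localhost:8080"  # placeholder; depends on deployment and API prefix

response = requests.post(
    f"{BASE_URL}/connections/test",  # route registered on connections_router above
    headers={"Authorization": "Bearer <token>"},  # placeholder auth
    json={
        "connection_id": "my_postgres",  # hypothetical existing connection id
        "conn_type": "postgres",
        "password": "***",  # redacted mask; the stored password is used instead
    },
)
print(response.json())  # {"status": ..., "message": ...}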


class TestCreateDefaultConnections(TestConnectionEndpoint):
def test_should_respond_204(self, test_client, session):