Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,11 @@ repos:
^airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_structure.py$|
^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py$|
^airflow-core/tests/unit/cli/commands/test_task_command.py$|
^airflow-core/tests/unit/cli/commands/test_team_command.py$|
^airflow-core/tests/unit/cli/commands/test_pool_command.py$|
^airflow-core/tests/unit/cli/commands/test_connection_command.py$|
^airflow-core/tests/unit/cli/commands/test_dag_command.py$|
^airflow-core/tests/unit/cli/commands/test_rotate_fernet_key_command.py$|
^airflow-core/tests/unit/dag_processing/bundles/test_dag_bundle_manager.py$|
^airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py$|
^airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py$|
Expand Down
13 changes: 8 additions & 5 deletions airflow-core/tests/unit/cli/commands/test_connection_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from io import StringIO

import pytest
from sqlalchemy import select

from airflow.cli import cli_config, cli_parser
from airflow.cli.commands import connection_command
Expand Down Expand Up @@ -596,7 +597,7 @@ def test_cli_connection_add(self, cmd, expected_output, expected_conn, session,
"schema",
"extra",
]
current_conn = session.query(Connection).filter(Connection.conn_id == conn_id).first()
current_conn = session.scalar(select(Connection).where(Connection.conn_id == conn_id))
assert expected_conn == {attr: getattr(current_conn, attr) for attr in comparable_attrs}

def test_cli_connections_add_duplicate(self):
Expand Down Expand Up @@ -704,7 +705,7 @@ def test_cli_delete_connections(self, session, stdout_capture):
assert "Successfully deleted connection with `conn_id`=new1" in stdout.getvalue()

# Check deletions
result = session.query(Connection).filter(Connection.conn_id == "new1").first()
result = session.scalar(select(Connection).where(Connection.conn_id == "new1"))

assert result is None

Expand Down Expand Up @@ -793,7 +794,9 @@ def test_cli_connections_import_should_load_connections(self, mocker):
expected_imported = {k: v for k, v in expected_connections.items() if k != "new3"}

with create_session() as session:
current_conns = session.query(Connection).filter(Connection.conn_id.in_(["new0", "new1"])).all()
current_conns = session.scalars(
select(Connection).where(Connection.conn_id.in_(["new0", "new1"]))
).all()

comparable_attrs = [
"conn_id",
Expand Down Expand Up @@ -862,7 +865,7 @@ def test_cli_connections_import_should_not_overwrite_existing_connections(self,
assert "Could not import connection new3: connection already exists." in mock_print.call_args[0][0]

# Verify that the imported connections match the expected, sample connections
current_conns = session.query(Connection).all()
current_conns = session.scalars(select(Connection)).all()

comparable_attrs = [
"conn_id",
Expand Down Expand Up @@ -933,7 +936,7 @@ def test_cli_connections_import_should_overwrite_existing_connections(self, mock
"Could not import connection new3: connection already exists." not in mock_print.call_args[0][0]
)
# Verify that the imported connections match the expected, sample connections
current_conns = session.query(Connection).all()
current_conns = session.scalars(select(Connection)).all()
comparable_attrs = [
"conn_id",
"conn_type",
Expand Down
16 changes: 9 additions & 7 deletions airflow-core/tests/unit/cli/commands/test_dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import pendulum
import pytest
import time_machine
from sqlalchemy import select
from sqlalchemy import func, select

from airflow import settings
from airflow._shared.timezones import timezone
Expand Down Expand Up @@ -513,7 +513,7 @@ def test_trigger_dag(self):
),
)
with create_session() as session:
dagrun = session.query(DagRun).filter(DagRun.run_id == "test_trigger_dag").one()
dagrun = session.scalars(select(DagRun).where(DagRun.run_id == "test_trigger_dag")).one()

assert dagrun, "DagRun not created"
assert dagrun.run_type == DagRunType.MANUAL
Expand Down Expand Up @@ -541,7 +541,9 @@ def test_trigger_dag_with_microseconds(self):
)

with create_session() as session:
dagrun = session.query(DagRun).filter(DagRun.run_id == "test_trigger_dag_with_micro").one()
dagrun = session.scalars(
select(DagRun).where(DagRun.run_id == "test_trigger_dag_with_micro")
).one()

assert dagrun, "DagRun not created"
assert dagrun.run_type == DagRunType.MANUAL
Expand Down Expand Up @@ -594,7 +596,7 @@ def test_delete_dag(self):
session.add(DM(dag_id=key, bundle_name="dags-folder"))
session.commit()
dag_command.dag_delete(self.parser.parse_args(["dags", "delete", key, "--yes"]))
assert session.query(DM).filter_by(dag_id=key).count() == 0
assert session.scalar(select(func.count()).select_from(DM).where(DM.dag_id == key)) == 0
with pytest.raises(AirflowException):
dag_command.dag_delete(
self.parser.parse_args(["dags", "delete", "does_not_exist_dag", "--yes"]),
Expand Down Expand Up @@ -624,7 +626,7 @@ def test_dag_delete_when_backfill_and_dagrun_exist(self):
)
session.commit()
dag_command.dag_delete(self.parser.parse_args(["dags", "delete", key, "--yes"]))
assert session.query(DM).filter_by(dag_id=key).count() == 0
assert session.scalar(select(func.count()).select_from(DM).where(DM.dag_id == key)) == 0
with pytest.raises(AirflowException):
dag_command.dag_delete(
self.parser.parse_args(["dags", "delete", "does_not_exist_dag", "--yes"]),
Expand All @@ -640,7 +642,7 @@ def test_delete_dag_existing_file(self, tmp_path):
session.add(DM(dag_id=key, bundle_name="dags-folder", fileloc=os.fspath(path)))
session.commit()
dag_command.dag_delete(self.parser.parse_args(["dags", "delete", key, "--yes"]))
assert session.query(DM).filter_by(dag_id=key).count() == 0
assert session.scalar(select(func.count()).select_from(DM).where(DM.dag_id == key)) == 0

def test_cli_list_jobs(self):
args = self.parser.parse_args(["dags", "list-jobs"])
Expand Down Expand Up @@ -995,7 +997,7 @@ def test_reserialize(self, configure_dag_bundles, session):
assert serialized_dag_ids == {"test_example_bash_operator", "test_dag_with_no_tags", "test_sensor"}

example_bash_op = session.execute(
select(DagModel).filter(DagModel.dag_id == "test_example_bash_operator")
select(DagModel).where(DagModel.dag_id == "test_example_bash_operator")
).scalar()
assert example_bash_op.relative_fileloc == "." # the file _is_ the bundle path
assert example_bash_op.fileloc == str(TEST_DAGS_FOLDER / "test_example_bash_operator.py")
Expand Down
15 changes: 8 additions & 7 deletions airflow-core/tests/unit/cli/commands/test_pool_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import json

import pytest
from sqlalchemy import delete, func, select

from airflow import models, settings
from airflow.cli import cli_parser
Expand Down Expand Up @@ -47,7 +48,7 @@ def tearDown(self):
def _cleanup(session=None):
if session is None:
session = Session()
session.query(Pool).filter(Pool.pool != Pool.DEFAULT_POOL_NAME).delete()
session.execute(delete(Pool).where(Pool.pool != Pool.DEFAULT_POOL_NAME))
session.commit()
add_default_pool_if_not_exists()
session.close()
Expand All @@ -64,19 +65,19 @@ def test_pool_list_with_args(self):

def test_pool_create(self):
pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"]))
assert self.session.query(Pool).count() == 2
assert self.session.scalar(select(func.count()).select_from(Pool)) == 2

def test_pool_update_deferred(self):
pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"]))
assert self.session.query(Pool).filter(Pool.pool == "foo").first().include_deferred is False
assert self.session.scalar(select(Pool).where(Pool.pool == "foo")).include_deferred is False

pool_command.pool_set(
self.parser.parse_args(["pools", "set", "foo", "1", "test", "--include-deferred"])
)
assert self.session.query(Pool).filter(Pool.pool == "foo").first().include_deferred is True
assert self.session.scalar(select(Pool).where(Pool.pool == "foo")).include_deferred is True

pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"]))
assert self.session.query(Pool).filter(Pool.pool == "foo").first().include_deferred is False
assert self.session.scalar(select(Pool).where(Pool.pool == "foo")).include_deferred is False

def test_pool_get(self):
pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"]))
Expand All @@ -85,7 +86,7 @@ def test_pool_get(self):
def test_pool_delete(self):
pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"]))
pool_command.pool_delete(self.parser.parse_args(["pools", "delete", "foo"]))
assert self.session.query(Pool).count() == 1
assert self.session.scalar(select(func.count()).select_from(Pool)) == 1

def test_pool_import_nonexistent(self):
with pytest.raises(SystemExit):
Expand Down Expand Up @@ -123,7 +124,7 @@ def test_pool_import_backwards_compatibility(self, tmp_path):

pool_command.pool_import(self.parser.parse_args(["pools", "import", str(pool_import_file_path)]))

assert self.session.query(Pool).filter(Pool.pool == "foo").first().include_deferred is False
assert self.session.scalar(select(Pool).where(Pool.pool == "foo")).include_deferred is False

def test_pool_import_export(self, tmp_path):
pool_import_file_path = tmp_path / "pools_import.json"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import pytest
from cryptography.fernet import Fernet
from sqlalchemy import select

from airflow.cli import cli_parser
from airflow.cli.commands import rotate_fernet_key_command
Expand Down Expand Up @@ -71,7 +72,7 @@ def test_should_rotate_variable(self, session):
# Assert correctness using a new fernet key
with conf_vars({("core", "fernet_key"): fernet_key2.decode()}):
get_fernet.cache_clear() # Clear cached fernet
var1 = session.query(Variable).filter(Variable.key == var1_key).first()
var1 = session.scalar(select(Variable).where(Variable.key == var1_key))
# Unencrypted variable should be unchanged
assert Variable.get(key=var1_key) == "value"
assert var1._val == "value"
Expand Down Expand Up @@ -103,7 +104,7 @@ def test_should_rotate_connection(self, session, mock_supervisor_comms):
rotate_fernet_key_command.rotate_fernet_key(args)

def mock_get_connection(conn_id):
conn = session.query(Connection).filter(Connection.conn_id == conn_id).first()
conn = session.scalar(select(Connection).where(Connection.conn_id == conn_id))
if conn:
from airflow.sdk.execution_time.comms import ConnectionResult

Expand Down
29 changes: 15 additions & 14 deletions airflow-core/tests/unit/cli/commands/test_team_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from unittest.mock import patch

import pytest
from sqlalchemy import select

from airflow import models, settings
from airflow.cli import cli_parser
Expand Down Expand Up @@ -67,7 +68,7 @@ def test_team_create_success(self, stdout_capture):
team_command.team_create(self.parser.parse_args(["teams", "create", "test-team"]))

# Verify team was created in database
team = self.session.query(Team).filter(Team.name == "test-team").first()
team = self.session.scalar(select(Team).where(Team.name == "test-team"))
assert team is not None
assert team.name == "test-team"

Expand Down Expand Up @@ -139,15 +140,15 @@ def test_team_delete_success(self, stdout_capture):
team_command.team_create(self.parser.parse_args(["teams", "create", "delete-me"]))

# Verify team exists
team = self.session.query(Team).filter(Team.name == "delete-me").first()
team = self.session.scalar(select(Team).where(Team.name == "delete-me"))
assert team is not None

# Delete team with --yes flag
with stdout_capture as stdout:
team_command.team_delete(self.parser.parse_args(["teams", "delete", "delete-me", "--yes"]))

# Verify team was deleted
team = self.session.query(Team).filter(Team.name == "delete-me").first()
team = self.session.scalar(select(Team).where(Team.name == "delete-me"))
assert team is None

# Verify output message
Expand All @@ -168,7 +169,7 @@ def test_team_delete_with_dag_bundle_association(self):
"""Test deleting team that has DAG bundle associations."""
# Create team
team_command.team_create(self.parser.parse_args(["teams", "create", "bundle-team"]))
team = self.session.query(Team).filter(Team.name == "bundle-team").first()
team = self.session.scalar(select(Team).where(Team.name == "bundle-team"))

# Create a DAG bundle first
dag_bundle = DagBundleModel(name="test-bundle")
Expand All @@ -194,7 +195,7 @@ def test_team_delete_with_connection_association(self):
"""Test deleting team that has connection associations."""
# Create team
team_command.team_create(self.parser.parse_args(["teams", "create", "conn-team"]))
team = self.session.query(Team).filter(Team.name == "conn-team").first()
team = self.session.scalar(select(Team).where(Team.name == "conn-team"))

# Create connection associated with team
conn = Connection(conn_id="test-conn", conn_type="http", team_name=team.name)
Expand All @@ -212,7 +213,7 @@ def test_team_delete_with_variable_association(self):
"""Test deleting team that has variable associations."""
# Create team
team_command.team_create(self.parser.parse_args(["teams", "create", "var-team"]))
team = self.session.query(Team).filter(Team.name == "var-team").first()
team = self.session.scalar(select(Team).where(Team.name == "var-team"))

# Create variable associated with team
var = Variable(key="test-var", val="test-value", team_name=team.name)
Expand All @@ -229,7 +230,7 @@ def test_team_delete_with_pool_association(self):
"""Test deleting team that has pool associations."""
# Create team
team_command.team_create(self.parser.parse_args(["teams", "create", "pool-team"]))
team = self.session.query(Team).filter(Team.name == "pool-team").first()
team = self.session.scalar(select(Team).where(Team.name == "pool-team"))

# Create pool associated with team
pool = Pool(
Expand All @@ -248,7 +249,7 @@ def test_team_delete_with_multiple_associations(self):
"""Test deleting team that has multiple types of associations."""
# Create team
team_command.team_create(self.parser.parse_args(["teams", "create", "multi-team"]))
team = self.session.query(Team).filter(Team.name == "multi-team").first()
team = self.session.scalar(select(Team).where(Team.name == "multi-team"))

# Create a DAG bundle first
dag_bundle = DagBundleModel(name="multi-bundle")
Expand Down Expand Up @@ -292,7 +293,7 @@ def test_team_delete_with_confirmation_yes(self, mock_input, stdout_capture):
team_command.team_delete(self.parser.parse_args(["teams", "delete", "confirm-yes"]))

# Verify team was deleted
team = self.session.query(Team).filter(Team.name == "confirm-yes").first()
team = self.session.scalar(select(Team).where(Team.name == "confirm-yes"))
assert team is None

output = stdout.getvalue()
Expand All @@ -309,7 +310,7 @@ def test_team_delete_with_confirmation_no(self, mock_input, stdout_capture):
team_command.team_delete(self.parser.parse_args(["teams", "delete", "confirm-no"]))

# Verify team was NOT deleted
team = self.session.query(Team).filter(Team.name == "confirm-no").first()
team = self.session.scalar(select(Team).where(Team.name == "confirm-no"))
assert team is not None

output = stdout.getvalue()
Expand All @@ -326,7 +327,7 @@ def test_team_delete_with_confirmation_invalid(self, mock_input, stdout_capture)
team_command.team_delete(self.parser.parse_args(["teams", "delete", "confirm-invalid"]))

# Verify team was NOT deleted (invalid input treated as No)
team = self.session.query(Team).filter(Team.name == "confirm-invalid").first()
team = self.session.scalar(select(Team).where(Team.name == "confirm-invalid"))
assert team is not None

output = stdout.getvalue()
Expand All @@ -335,7 +336,7 @@ def test_team_delete_with_confirmation_invalid(self, mock_input, stdout_capture)
def test_team_operations_integration(self):
"""Test integration of create, list, and delete operations."""
# Start with empty state
teams = self.session.query(Team).all()
teams = self.session.scalars(select(Team)).all()
assert len(teams) == 0

# Create multiple teams
Expand All @@ -344,7 +345,7 @@ def test_team_operations_integration(self):
team_command.team_create(self.parser.parse_args(["teams", "create", "integration-3"]))

# Verify all teams exist
teams = self.session.query(Team).all()
teams = self.session.scalars(select(Team)).all()
assert len(teams) == 3
team_names = [team.name for team in teams]
assert "integration-1" in team_names
Expand All @@ -355,7 +356,7 @@ def test_team_operations_integration(self):
team_command.team_delete(self.parser.parse_args(["teams", "delete", "integration-2", "--yes"]))

# Verify correct team was deleted
teams = self.session.query(Team).all()
teams = self.session.scalars(select(Team)).all()
assert len(teams) == 2
team_names = [team.name for team in teams]
assert "integration-1" in team_names
Expand Down