diff --git a/providers/src/airflow/providers/celery/executors/celery_executor.py b/providers/src/airflow/providers/celery/executors/celery_executor.py index 5632f94e22cb1..9cd8d1550cde9 100644 --- a/providers/src/airflow/providers/celery/executors/celery_executor.py +++ b/providers/src/airflow/providers/celery/executors/celery_executor.py @@ -160,15 +160,12 @@ def __getattr__(name): action="store_true", ) -CELERY_CLI_COMMAND_PATH = ( - "airflow.providers.celery.cli.celery_command" - if AIRFLOW_V_2_8_PLUS - else ( - "airflow.cli.commands.local_commands.celery_command" - if AIRFLOW_V_3_0_PLUS - else "airflow.cli.commands.celery_command" - ) -) +if AIRFLOW_V_2_8_PLUS: + CELERY_CLI_COMMAND_PATH = "airflow.providers.celery.cli.celery_command" +elif AIRFLOW_V_3_0_PLUS: + CELERY_CLI_COMMAND_PATH = "airflow.cli.commands.local_commands.celery_command" +else: + CELERY_CLI_COMMAND_PATH = "airflow.cli.commands.celery_command" CELERY_COMMANDS = ( ActionCommand( diff --git a/providers/src/airflow/providers/fab/auth_manager/cli_commands/db_command.py b/providers/src/airflow/providers/fab/auth_manager/cli_commands/db_command.py index 81b8b378802b0..ad3127c9659f7 100644 --- a/providers/src/airflow/providers/fab/auth_manager/cli_commands/db_command.py +++ b/providers/src/airflow/providers/fab/auth_manager/cli_commands/db_command.py @@ -24,15 +24,11 @@ def get_db_command(): - try: - if AIRFLOW_V_3_0_PLUS: - import airflow.cli.commands.local_commands.db_command as db_command - else: - import airflow.cli.commands.db_command as db_command - except ImportError: - from airflow.exceptions import AirflowOptionalProviderFeatureException - - raise AirflowOptionalProviderFeatureException("Failed to import db_command from Airflow CLI.") + """Import the correct db_command module based on the Airflow version.""" + if AIRFLOW_V_3_0_PLUS: + import airflow.cli.commands.local_commands.db_command as db_command + else: + import airflow.cli.commands.db_command as db_command return db_command diff --git a/providers/tests/celery/cli/test_celery_command.py b/providers/tests/celery/cli/test_celery_command.py index c6bd216aced6e..bed8866838670 100644 --- a/providers/tests/celery/cli/test_celery_command.py +++ b/providers/tests/celery/cli/test_celery_command.py @@ -270,6 +270,97 @@ def test_run_command(self, mock_celery_app): ] ) + def _test_run_command_daemon(self, mock_celery_app, mock_daemon, mock_setup_locations, mock_pid_file): + mock_setup_locations.return_value = ( + mock.MagicMock(name="pidfile"), + mock.MagicMock(name="stdout"), + mock.MagicMock(name="stderr"), + mock.MagicMock(name="INVALID"), + ) + args = self.parser.parse_args( + [ + "celery", + "flower", + "--basic-auth", + "admin:admin", + "--broker-api", + "http://username:password@rabbitmq-server-name:15672/api/", + "--flower-conf", + "flower_config", + "--hostname", + "my-hostname", + "--log-file", + "/tmp/flower.log", + "--pid", + "/tmp/flower.pid", + "--port", + "3333", + "--stderr", + "/tmp/flower-stderr.log", + "--stdout", + "/tmp/flower-stdout.log", + "--url-prefix", + "flower-monitoring", + "--daemon", + ] + ) + mock_open = mock.mock_open() + with mock.patch("airflow.cli.commands.local_commands.daemon_utils.open", mock_open): + celery_command.flower(args) + + mock_celery_app.start.assert_called_once_with( + [ + "flower", + conf.get("celery", "BROKER_URL"), + "--address=my-hostname", + "--port=3333", + "--broker-api=http://username:password@rabbitmq-server-name:15672/api/", + "--url-prefix=flower-monitoring", + "--basic-auth=admin:admin", + "--conf=flower_config", + ] + ) + assert mock_daemon.mock_calls[:3] == [ + mock.call.DaemonContext( + pidfile=mock_pid_file.return_value, + files_preserve=None, + stdout=mock_open.return_value, + stderr=mock_open.return_value, + umask=0o077, + ), + mock.call.DaemonContext().__enter__(), + mock.call.DaemonContext().__exit__(None, None, None), + ] + + assert mock_setup_locations.mock_calls == [ + mock.call( + process="flower", + pid="/tmp/flower.pid", + stdout="/tmp/flower-stdout.log", + stderr="/tmp/flower-stderr.log", + log="/tmp/flower.log", + ) + if AIRFLOW_V_2_10_PLUS + else mock.call( + process="flower", + stdout="/tmp/flower-stdout.log", + stderr="/tmp/flower-stderr.log", + log="/tmp/flower.log", + ) + ] + mock_pid_file.assert_has_calls([mock.call(mock_setup_locations.return_value[0], -1)]) + assert mock_open.mock_calls == [ + mock.call(mock_setup_locations.return_value[1], "a"), + mock.call().__enter__(), + mock.call(mock_setup_locations.return_value[2], "a"), + mock.call().__enter__(), + mock.call().truncate(0), + mock.call().truncate(0), + mock.call().__exit__(None, None, None), + mock.call().__exit__(None, None, None), + ] + + @pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0-") @mock.patch("airflow.cli.commands.daemon_utils.TimeoutPIDLockFile") @mock.patch("airflow.cli.commands.daemon_utils.setup_locations") @mock.patch("airflow.cli.commands.daemon_utils.daemon") @@ -277,96 +368,9 @@ def test_run_command(self, mock_celery_app): def test_run_command_daemon_v_3_below( self, mock_celery_app, mock_daemon, mock_setup_locations, mock_pid_file ): - if not AIRFLOW_V_3_0_PLUS: - mock_setup_locations.return_value = ( - mock.MagicMock(name="pidfile"), - mock.MagicMock(name="stdout"), - mock.MagicMock(name="stderr"), - mock.MagicMock(name="INVALID"), - ) - args = self.parser.parse_args( - [ - "celery", - "flower", - "--basic-auth", - "admin:admin", - "--broker-api", - "http://username:password@rabbitmq-server-name:15672/api/", - "--flower-conf", - "flower_config", - "--hostname", - "my-hostname", - "--log-file", - "/tmp/flower.log", - "--pid", - "/tmp/flower.pid", - "--port", - "3333", - "--stderr", - "/tmp/flower-stderr.log", - "--stdout", - "/tmp/flower-stdout.log", - "--url-prefix", - "flower-monitoring", - "--daemon", - ] - ) - mock_open = mock.mock_open() - with mock.patch("airflow.cli.commands.local_commands.daemon_utils.open", mock_open): - celery_command.flower(args) - - mock_celery_app.start.assert_called_once_with( - [ - "flower", - conf.get("celery", "BROKER_URL"), - "--address=my-hostname", - "--port=3333", - "--broker-api=http://username:password@rabbitmq-server-name:15672/api/", - "--url-prefix=flower-monitoring", - "--basic-auth=admin:admin", - "--conf=flower_config", - ] - ) - assert mock_daemon.mock_calls[:3] == [ - mock.call.DaemonContext( - pidfile=mock_pid_file.return_value, - files_preserve=None, - stdout=mock_open.return_value, - stderr=mock_open.return_value, - umask=0o077, - ), - mock.call.DaemonContext().__enter__(), - mock.call.DaemonContext().__exit__(None, None, None), - ] - - assert mock_setup_locations.mock_calls == [ - mock.call( - process="flower", - pid="/tmp/flower.pid", - stdout="/tmp/flower-stdout.log", - stderr="/tmp/flower-stderr.log", - log="/tmp/flower.log", - ) - if AIRFLOW_V_2_10_PLUS - else mock.call( - process="flower", - stdout="/tmp/flower-stdout.log", - stderr="/tmp/flower-stderr.log", - log="/tmp/flower.log", - ) - ] - mock_pid_file.assert_has_calls([mock.call(mock_setup_locations.return_value[0], -1)]) - assert mock_open.mock_calls == [ - mock.call(mock_setup_locations.return_value[1], "a"), - mock.call().__enter__(), - mock.call(mock_setup_locations.return_value[2], "a"), - mock.call().__enter__(), - mock.call().truncate(0), - mock.call().truncate(0), - mock.call().__exit__(None, None, None), - mock.call().__exit__(None, None, None), - ] + self._test_run_command_daemon(mock_celery_app, mock_daemon, mock_setup_locations, mock_pid_file) + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+") @mock.patch("airflow.cli.commands.local_commands.daemon_utils.TimeoutPIDLockFile") @mock.patch("airflow.cli.commands.local_commands.daemon_utils.setup_locations") @mock.patch("airflow.cli.commands.local_commands.daemon_utils.daemon") @@ -374,92 +378,4 @@ def test_run_command_daemon_v_3_below( def test_run_command_daemon_v3_above( self, mock_celery_app, mock_daemon, mock_setup_locations, mock_pid_file ): - if AIRFLOW_V_3_0_PLUS: - mock_setup_locations.return_value = ( - mock.MagicMock(name="pidfile"), - mock.MagicMock(name="stdout"), - mock.MagicMock(name="stderr"), - mock.MagicMock(name="INVALID"), - ) - args = self.parser.parse_args( - [ - "celery", - "flower", - "--basic-auth", - "admin:admin", - "--broker-api", - "http://username:password@rabbitmq-server-name:15672/api/", - "--flower-conf", - "flower_config", - "--hostname", - "my-hostname", - "--log-file", - "/tmp/flower.log", - "--pid", - "/tmp/flower.pid", - "--port", - "3333", - "--stderr", - "/tmp/flower-stderr.log", - "--stdout", - "/tmp/flower-stdout.log", - "--url-prefix", - "flower-monitoring", - "--daemon", - ] - ) - mock_open = mock.mock_open() - with mock.patch("airflow.cli.commands.local_commands.daemon_utils.open", mock_open): - celery_command.flower(args) - - mock_celery_app.start.assert_called_once_with( - [ - "flower", - conf.get("celery", "BROKER_URL"), - "--address=my-hostname", - "--port=3333", - "--broker-api=http://username:password@rabbitmq-server-name:15672/api/", - "--url-prefix=flower-monitoring", - "--basic-auth=admin:admin", - "--conf=flower_config", - ] - ) - assert mock_daemon.mock_calls[:3] == [ - mock.call.DaemonContext( - pidfile=mock_pid_file.return_value, - files_preserve=None, - stdout=mock_open.return_value, - stderr=mock_open.return_value, - umask=0o077, - ), - mock.call.DaemonContext().__enter__(), - mock.call.DaemonContext().__exit__(None, None, None), - ] - - assert mock_setup_locations.mock_calls == [ - mock.call( - process="flower", - pid="/tmp/flower.pid", - stdout="/tmp/flower-stdout.log", - stderr="/tmp/flower-stderr.log", - log="/tmp/flower.log", - ) - if AIRFLOW_V_2_10_PLUS - else mock.call( - process="flower", - stdout="/tmp/flower-stdout.log", - stderr="/tmp/flower-stderr.log", - log="/tmp/flower.log", - ) - ] - mock_pid_file.assert_has_calls([mock.call(mock_setup_locations.return_value[0], -1)]) - assert mock_open.mock_calls == [ - mock.call(mock_setup_locations.return_value[1], "a"), - mock.call().__enter__(), - mock.call(mock_setup_locations.return_value[2], "a"), - mock.call().__enter__(), - mock.call().truncate(0), - mock.call().truncate(0), - mock.call().__exit__(None, None, None), - mock.call().__exit__(None, None, None), - ] + self._test_run_command_daemon(mock_celery_app, mock_daemon, mock_setup_locations, mock_pid_file)