Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions airflow/operators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@

import dill

from airflow.exceptions import AirflowConfigException, AirflowException, RemovedInAirflow3Warning
from airflow.exceptions import (
AirflowConfigException,
AirflowException,
AirflowSkipException,
RemovedInAirflow3Warning,
)
from airflow.models.baseoperator import BaseOperator
from airflow.models.skipmixin import SkipMixin
from airflow.models.taskinstance import _CURRENT_CONTEXT
Expand Down Expand Up @@ -466,6 +471,9 @@ class PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
:param expect_airflow: expect Airflow to be installed in the target environment. If true, the operator
will raise warning if Airflow is not installed, and it will attempt to load Airflow
macros when starting.
:param skip_exit_code: If python_callable exits with this exit code, leave the task
in ``skipped`` state (default: None). If set to ``None``, any non-zero
exit code will be treated as a failure.
"""

template_fields: Sequence[str] = tuple({"requirements"} | set(PythonOperator.template_fields))
Expand All @@ -486,6 +494,7 @@ def __init__(
templates_dict: dict | None = None,
templates_exts: list[str] | None = None,
expect_airflow: bool = True,
skip_exit_code: int | None = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’m late but this name feels like it’s skipping the exit code; would it be better to name it e.g. skip_on_exit_code instead? Also would it be a good idea to allow multiple codes?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, good suggestions, I'll create a new PR to address them

**kwargs,
):
if (
Expand All @@ -509,6 +518,7 @@ def __init__(
self.python_version = python_version
self.system_site_packages = system_site_packages
self.pip_install_options = pip_install_options
self.skip_exit_code = skip_exit_code
super().__init__(
python_callable=python_callable,
use_dill=use_dill,
Expand Down Expand Up @@ -544,8 +554,14 @@ def execute_callable(self):
pip_install_options=self.pip_install_options,
)
python_path = tmp_path / "bin" / "python"

return self._execute_python_callable_in_subprocess(python_path, tmp_path)
try:
result = self._execute_python_callable_in_subprocess(python_path, tmp_path)
except subprocess.CalledProcessError as e:
if self.skip_exit_code and e.returncode == self.skip_exit_code:
raise AirflowSkipException(f"Process exited with code {self.skip_exit_code}. Skipping.")
else:
raise
return result

def _iter_serializable_context_keys(self):
yield from self.BASE_SERIALIZABLE_CONTEXT_KEYS
Expand Down
34 changes: 32 additions & 2 deletions tests/operators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from airflow.utils.context import AirflowContextDeprecationWarning, Context
from airflow.utils.python_virtualenv import prepare_virtualenv
from airflow.utils.session import create_session
from airflow.utils.state import DagRunState, State
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import NOTSET, DagRunType
from tests.test_utils import AIRFLOW_MAIN_FOLDER
Expand Down Expand Up @@ -131,10 +131,12 @@ def run_as_operator(self, fn, **kwargs):
task.run(start_date=self.default_date, end_date=self.default_date)
return task

def run_as_task(self, fn, **kwargs):
def run_as_task(self, fn, return_ti=False, **kwargs):
"""Create TaskInstance and run it."""
ti = self.create_ti(fn, **kwargs)
ti.run()
if return_ti:
return ti
return ti.task

def render_templates(self, fn, **kwargs):
Expand Down Expand Up @@ -932,6 +934,34 @@ def test_virtualenv_serializable_context_fields(self, create_task_instance):
}
assert set(context) == declared_keys

@pytest.mark.parametrize(
"extra_kwargs, actual_exit_code, expected_state",
[
(None, 99, TaskInstanceState.FAILED),
({"skip_exit_code": 100}, 100, TaskInstanceState.SKIPPED),
({"skip_exit_code": 100}, 101, TaskInstanceState.FAILED),
({"skip_exit_code": None}, 0, TaskInstanceState.SUCCESS),
],
)
def test_skip_exit_code(self, extra_kwargs, actual_exit_code, expected_state):
def f(exit_code):
if exit_code != 0:
raise SystemExit(exit_code)

if expected_state == TaskInstanceState.FAILED:
with pytest.raises(CalledProcessError):
self.run_as_task(
f, op_kwargs={"exit_code": actual_exit_code}, **(extra_kwargs if extra_kwargs else {})
)
else:
ti = self.run_as_task(
f,
return_ti=True,
op_kwargs={"exit_code": actual_exit_code},
**(extra_kwargs if extra_kwargs else {}),
)
assert ti.state == expected_state


class TestCurrentContext:
def test_current_context_no_context_raise(self):
Expand Down